forked from open-webui/open-webui
Merge pull request #1554 from open-webui/external-embeddings
feat: external embeddings
This commit is contained in:
commit
54a4b7db14
6 changed files with 288 additions and 101 deletions
|
@ -659,7 +659,7 @@ def generate_ollama_embeddings(
|
||||||
url_idx: Optional[int] = None,
|
url_idx: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
log.info("generate_ollama_embeddings", form_data)
|
log.info(f"generate_ollama_embeddings {form_data}")
|
||||||
|
|
||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
|
@ -688,7 +688,7 @@ def generate_ollama_embeddings(
|
||||||
|
|
||||||
data = r.json()
|
data = r.json()
|
||||||
|
|
||||||
log.info("generate_ollama_embeddings", data)
|
log.info(f"generate_ollama_embeddings {data}")
|
||||||
|
|
||||||
if "embedding" in data:
|
if "embedding" in data:
|
||||||
return data["embedding"]
|
return data["embedding"]
|
||||||
|
|
|
@ -53,6 +53,7 @@ from apps.rag.utils import (
|
||||||
query_collection,
|
query_collection,
|
||||||
query_embeddings_collection,
|
query_embeddings_collection,
|
||||||
get_embedding_model_path,
|
get_embedding_model_path,
|
||||||
|
generate_openai_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
|
@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||||
|
|
||||||
|
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
|
||||||
|
app.state.RAG_OPENAI_API_KEY = ""
|
||||||
|
|
||||||
app.state.PDF_EXTRACT_IMAGES = False
|
app.state.PDF_EXTRACT_IMAGES = False
|
||||||
|
|
||||||
|
@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
"openai_config": {
|
||||||
|
"url": app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
"key": app.state.RAG_OPENAI_API_KEY,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIConfigForm(BaseModel):
|
||||||
|
url: str
|
||||||
|
key: str
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModelUpdateForm(BaseModel):
|
class EmbeddingModelUpdateForm(BaseModel):
|
||||||
|
openai_config: Optional[OpenAIConfigForm] = None
|
||||||
embedding_engine: str
|
embedding_engine: str
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
|
|
||||||
|
@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
|
||||||
async def update_embedding_config(
|
async def update_embedding_config(
|
||||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||||
app.state.sentence_transformer_ef = None
|
app.state.sentence_transformer_ef = None
|
||||||
|
|
||||||
|
if form_data.openai_config != None:
|
||||||
|
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||||
|
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
|
||||||
else:
|
else:
|
||||||
sentence_transformer_ef = (
|
sentence_transformer_ef = (
|
||||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||||
|
@ -183,6 +198,10 @@ async def update_embedding_config(
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
"openai_config": {
|
||||||
|
"url": app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
"key": app.state.RAG_OPENAI_API_KEY,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -275,6 +294,14 @@ def query_doc_handler(
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
|
return query_doc(
|
||||||
|
collection_name=form_data.collection_name,
|
||||||
|
query=form_data.query,
|
||||||
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
embedding_function=app.state.sentence_transformer_ef,
|
||||||
|
)
|
||||||
|
else:
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||||
query_embeddings = generate_ollama_embeddings(
|
query_embeddings = generate_ollama_embeddings(
|
||||||
GenerateEmbeddingsForm(
|
GenerateEmbeddingsForm(
|
||||||
|
@ -284,19 +311,20 @@ def query_doc_handler(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||||
|
query_embeddings = generate_openai_embeddings(
|
||||||
|
model=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
text=form_data.query,
|
||||||
|
key=app.state.RAG_OPENAI_API_KEY,
|
||||||
|
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
return query_embeddings_doc(
|
return query_embeddings_doc(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query_embeddings=query_embeddings,
|
query_embeddings=query_embeddings,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return query_doc(
|
|
||||||
collection_name=form_data.collection_name,
|
|
||||||
query=form_data.query,
|
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
||||||
embedding_function=app.state.sentence_transformer_ef,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -317,6 +345,15 @@ def query_collection_handler(
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
|
return query_collection(
|
||||||
|
collection_names=form_data.collection_names,
|
||||||
|
query=form_data.query,
|
||||||
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
embedding_function=app.state.sentence_transformer_ef,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||||
query_embeddings = generate_ollama_embeddings(
|
query_embeddings = generate_ollama_embeddings(
|
||||||
GenerateEmbeddingsForm(
|
GenerateEmbeddingsForm(
|
||||||
|
@ -326,19 +363,20 @@ def query_collection_handler(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||||
|
query_embeddings = generate_openai_embeddings(
|
||||||
|
model=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
text=form_data.query,
|
||||||
|
key=app.state.RAG_OPENAI_API_KEY,
|
||||||
|
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
return query_embeddings_collection(
|
return query_embeddings_collection(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
query_embeddings=query_embeddings,
|
query_embeddings=query_embeddings,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return query_collection(
|
|
||||||
collection_names=form_data.collection_names,
|
|
||||||
query=form_data.query,
|
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
||||||
embedding_function=app.state.sentence_transformer_ef,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -383,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
|
||||||
docs = text_splitter.split_documents(data)
|
docs = text_splitter.split_documents(data)
|
||||||
|
|
||||||
if len(docs) > 0:
|
if len(docs) > 0:
|
||||||
log.info("store_data_in_vector_db", "store_docs_in_vector_db")
|
log.info(f"store_data_in_vector_db {docs}")
|
||||||
return store_docs_in_vector_db(docs, collection_name, overwrite), None
|
return store_docs_in_vector_db(docs, collection_name, overwrite), None
|
||||||
else:
|
else:
|
||||||
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
||||||
|
@ -402,7 +440,7 @@ def store_text_in_vector_db(
|
||||||
|
|
||||||
|
|
||||||
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
||||||
log.info("store_docs_in_vector_db", docs, collection_name)
|
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
||||||
|
|
||||||
texts = [doc.page_content for doc in docs]
|
texts = [doc.page_content for doc in docs]
|
||||||
metadatas = [doc.metadata for doc in docs]
|
metadatas = [doc.metadata for doc in docs]
|
||||||
|
@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
||||||
log.info(f"deleting existing collection {collection_name}")
|
log.info(f"deleting existing collection {collection_name}")
|
||||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
||||||
|
|
||||||
for batch in create_batches(
|
|
||||||
api=CHROMA_CLIENT,
|
|
||||||
ids=[str(uuid.uuid1()) for _ in texts],
|
|
||||||
metadatas=metadatas,
|
|
||||||
embeddings=[
|
|
||||||
generate_ollama_embeddings(
|
|
||||||
GenerateEmbeddingsForm(
|
|
||||||
**{"model": RAG_EMBEDDING_MODEL, "prompt": text}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for text in texts
|
|
||||||
],
|
|
||||||
):
|
|
||||||
collection.add(*batch)
|
|
||||||
else:
|
|
||||||
|
|
||||||
collection = CHROMA_CLIENT.create_collection(
|
collection = CHROMA_CLIENT.create_collection(
|
||||||
name=collection_name,
|
name=collection_name,
|
||||||
|
@ -446,6 +467,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
||||||
):
|
):
|
||||||
collection.add(*batch)
|
collection.add(*batch)
|
||||||
|
|
||||||
|
else:
|
||||||
|
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||||
|
|
||||||
|
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||||
|
embeddings = [
|
||||||
|
generate_ollama_embeddings(
|
||||||
|
GenerateEmbeddingsForm(
|
||||||
|
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
|
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||||
|
embeddings = [
|
||||||
|
generate_openai_embeddings(
|
||||||
|
model=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
text=text,
|
||||||
|
key=app.state.RAG_OPENAI_API_KEY,
|
||||||
|
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
|
|
||||||
|
for batch in create_batches(
|
||||||
|
api=CHROMA_CLIENT,
|
||||||
|
ids=[str(uuid.uuid1()) for _ in texts],
|
||||||
|
metadatas=metadatas,
|
||||||
|
embeddings=embeddings,
|
||||||
|
documents=texts,
|
||||||
|
):
|
||||||
|
collection.add(*batch)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
|
@ -6,9 +6,12 @@ import requests
|
||||||
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
||||||
|
|
||||||
|
|
||||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
|
||||||
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
|
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
|
||||||
try:
|
try:
|
||||||
# if you use docker use the model from the environment variable
|
# if you use docker use the model from the environment variable
|
||||||
log.info("query_embeddings_doc", query_embeddings)
|
log.info(f"query_embeddings_doc {query_embeddings}")
|
||||||
collection = CHROMA_CLIENT.get_collection(
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
name=collection_name,
|
name=collection_name,
|
||||||
)
|
)
|
||||||
|
@ -40,6 +43,8 @@ def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
|
||||||
query_embeddings=[query_embeddings],
|
query_embeddings=[query_embeddings],
|
||||||
n_results=k,
|
n_results=k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log.info(f"query_embeddings_doc:result {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -118,7 +123,7 @@ def query_collection(
|
||||||
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
|
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
log.info("query_embeddings_collection", query_embeddings)
|
log.info(f"query_embeddings_collection {query_embeddings}")
|
||||||
|
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
try:
|
try:
|
||||||
|
@ -141,8 +146,20 @@ def rag_template(template: str, context: str, query: str):
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
def rag_messages(docs, messages, template, k, embedding_function):
|
def rag_messages(
|
||||||
log.debug(f"docs: {docs}")
|
docs,
|
||||||
|
messages,
|
||||||
|
template,
|
||||||
|
k,
|
||||||
|
embedding_engine,
|
||||||
|
embedding_model,
|
||||||
|
embedding_function,
|
||||||
|
openai_key,
|
||||||
|
openai_url,
|
||||||
|
):
|
||||||
|
log.debug(
|
||||||
|
f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
|
||||||
|
)
|
||||||
|
|
||||||
last_user_message_idx = None
|
last_user_message_idx = None
|
||||||
for i in range(len(messages) - 1, -1, -1):
|
for i in range(len(messages) - 1, -1, -1):
|
||||||
|
@ -175,6 +192,11 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
if doc["type"] == "text":
|
||||||
|
context = doc["content"]
|
||||||
|
else:
|
||||||
|
if embedding_engine == "":
|
||||||
if doc["type"] == "collection":
|
if doc["type"] == "collection":
|
||||||
context = query_collection(
|
context = query_collection(
|
||||||
collection_names=doc["collection_names"],
|
collection_names=doc["collection_names"],
|
||||||
|
@ -182,8 +204,6 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
||||||
k=k,
|
k=k,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
)
|
)
|
||||||
elif doc["type"] == "text":
|
|
||||||
context = doc["content"]
|
|
||||||
else:
|
else:
|
||||||
context = query_doc(
|
context = query_doc(
|
||||||
collection_name=doc["collection_name"],
|
collection_name=doc["collection_name"],
|
||||||
|
@ -191,6 +211,38 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
||||||
k=k,
|
k=k,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if embedding_engine == "ollama":
|
||||||
|
query_embeddings = generate_ollama_embeddings(
|
||||||
|
GenerateEmbeddingsForm(
|
||||||
|
**{
|
||||||
|
"model": embedding_model,
|
||||||
|
"prompt": query,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif embedding_engine == "openai":
|
||||||
|
query_embeddings = generate_openai_embeddings(
|
||||||
|
model=embedding_model,
|
||||||
|
text=query,
|
||||||
|
key=openai_key,
|
||||||
|
url=openai_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
if doc["type"] == "collection":
|
||||||
|
context = query_embeddings_collection(
|
||||||
|
collection_names=doc["collection_names"],
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
context = query_embeddings_doc(
|
||||||
|
collection_name=doc["collection_name"],
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
context = None
|
context = None
|
||||||
|
@ -269,3 +321,26 @@ def get_embedding_model_path(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Cannot determine embedding model snapshot path: {e}")
|
log.exception(f"Cannot determine embedding model snapshot path: {e}")
|
||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
|
|
||||||
|
def generate_openai_embeddings(
|
||||||
|
model: str, text: str, key: str, url: str = "https://api.openai.com"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/v1/embeddings",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
},
|
||||||
|
json={"input": text, "model": model},
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
if "data" in data:
|
||||||
|
return data["data"][0]["embedding"]
|
||||||
|
else:
|
||||||
|
raise "Something went wrong :/"
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
|
@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||||
data["messages"],
|
data["messages"],
|
||||||
rag_app.state.RAG_TEMPLATE,
|
rag_app.state.RAG_TEMPLATE,
|
||||||
rag_app.state.TOP_K,
|
rag_app.state.TOP_K,
|
||||||
|
rag_app.state.RAG_EMBEDDING_ENGINE,
|
||||||
|
rag_app.state.RAG_EMBEDDING_MODEL,
|
||||||
rag_app.state.sentence_transformer_ef,
|
rag_app.state.sentence_transformer_ef,
|
||||||
|
rag_app.state.RAG_OPENAI_API_KEY,
|
||||||
|
rag_app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
)
|
)
|
||||||
del data["docs"]
|
del data["docs"]
|
||||||
|
|
||||||
|
|
|
@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type OpenAIConfigForm = {
|
||||||
|
key: string;
|
||||||
|
url: string;
|
||||||
|
};
|
||||||
|
|
||||||
type EmbeddingModelUpdateForm = {
|
type EmbeddingModelUpdateForm = {
|
||||||
|
openai_config?: OpenAIConfigForm;
|
||||||
embedding_engine: string;
|
embedding_engine: string;
|
||||||
embedding_model: string;
|
embedding_model: string;
|
||||||
};
|
};
|
||||||
|
|
|
@ -29,6 +29,9 @@
|
||||||
let embeddingEngine = '';
|
let embeddingEngine = '';
|
||||||
let embeddingModel = '';
|
let embeddingModel = '';
|
||||||
|
|
||||||
|
let openAIKey = '';
|
||||||
|
let openAIUrl = '';
|
||||||
|
|
||||||
let chunkSize = 0;
|
let chunkSize = 0;
|
||||||
let chunkOverlap = 0;
|
let chunkOverlap = 0;
|
||||||
let pdfExtractImages = true;
|
let pdfExtractImages = true;
|
||||||
|
@ -50,15 +53,6 @@
|
||||||
};
|
};
|
||||||
|
|
||||||
const embeddingModelUpdateHandler = async () => {
|
const embeddingModelUpdateHandler = async () => {
|
||||||
if (embeddingModel === '') {
|
|
||||||
toast.error(
|
|
||||||
$i18n.t(
|
|
||||||
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
|
|
||||||
)
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
|
if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
|
||||||
toast.error(
|
toast.error(
|
||||||
$i18n.t(
|
$i18n.t(
|
||||||
|
@ -67,21 +61,46 @@
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (embeddingEngine === 'ollama' && embeddingModel === '') {
|
||||||
|
toast.error(
|
||||||
|
$i18n.t(
|
||||||
|
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
|
||||||
|
)
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embeddingEngine === 'openai' && embeddingModel === '') {
|
||||||
|
toast.error(
|
||||||
|
$i18n.t(
|
||||||
|
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
|
||||||
|
)
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
|
||||||
|
toast.error($i18n.t('OpenAI URL/Key required.'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
console.log('Update embedding model attempt:', embeddingModel);
|
console.log('Update embedding model attempt:', embeddingModel);
|
||||||
|
|
||||||
updateEmbeddingModelLoading = true;
|
updateEmbeddingModelLoading = true;
|
||||||
const res = await updateEmbeddingConfig(localStorage.token, {
|
const res = await updateEmbeddingConfig(localStorage.token, {
|
||||||
embedding_engine: embeddingEngine,
|
embedding_engine: embeddingEngine,
|
||||||
embedding_model: embeddingModel
|
embedding_model: embeddingModel,
|
||||||
|
...(embeddingEngine === 'openai'
|
||||||
|
? {
|
||||||
|
openai_config: {
|
||||||
|
key: openAIKey,
|
||||||
|
url: openAIUrl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
: {})
|
||||||
}).catch(async (error) => {
|
}).catch(async (error) => {
|
||||||
toast.error(error);
|
toast.error(error);
|
||||||
|
await setEmbeddingConfig();
|
||||||
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
|
|
||||||
if (embeddingConfig) {
|
|
||||||
embeddingEngine = embeddingConfig.embedding_engine;
|
|
||||||
embeddingModel = embeddingConfig.embedding_model;
|
|
||||||
}
|
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
updateEmbeddingModelLoading = false;
|
updateEmbeddingModelLoading = false;
|
||||||
|
@ -89,7 +108,7 @@
|
||||||
if (res) {
|
if (res) {
|
||||||
console.log('embeddingModelUpdateHandler:', res);
|
console.log('embeddingModelUpdateHandler:', res);
|
||||||
if (res.status === true) {
|
if (res.status === true) {
|
||||||
toast.success($i18n.t('Model {{embedding_model}} update complete!', res), {
|
toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
|
||||||
duration: 1000 * 10
|
duration: 1000 * 10
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -107,6 +126,18 @@
|
||||||
querySettings = await updateQuerySettings(localStorage.token, querySettings);
|
querySettings = await updateQuerySettings(localStorage.token, querySettings);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const setEmbeddingConfig = async () => {
|
||||||
|
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
|
||||||
|
|
||||||
|
if (embeddingConfig) {
|
||||||
|
embeddingEngine = embeddingConfig.embedding_engine;
|
||||||
|
embeddingModel = embeddingConfig.embedding_model;
|
||||||
|
|
||||||
|
openAIKey = embeddingConfig.openai_config.key;
|
||||||
|
openAIUrl = embeddingConfig.openai_config.url;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
onMount(async () => {
|
onMount(async () => {
|
||||||
const res = await getRAGConfig(localStorage.token);
|
const res = await getRAGConfig(localStorage.token);
|
||||||
|
|
||||||
|
@ -117,12 +148,7 @@
|
||||||
chunkOverlap = res.chunk.chunk_overlap;
|
chunkOverlap = res.chunk.chunk_overlap;
|
||||||
}
|
}
|
||||||
|
|
||||||
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
|
await setEmbeddingConfig();
|
||||||
|
|
||||||
if (embeddingConfig) {
|
|
||||||
embeddingEngine = embeddingConfig.embedding_engine;
|
|
||||||
embeddingModel = embeddingConfig.embedding_model;
|
|
||||||
}
|
|
||||||
|
|
||||||
querySettings = await getQuerySettings(localStorage.token);
|
querySettings = await getQuerySettings(localStorage.token);
|
||||||
});
|
});
|
||||||
|
@ -146,15 +172,38 @@
|
||||||
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
|
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
|
||||||
bind:value={embeddingEngine}
|
bind:value={embeddingEngine}
|
||||||
placeholder="Select an embedding engine"
|
placeholder="Select an embedding engine"
|
||||||
on:change={() => {
|
on:change={(e) => {
|
||||||
|
if (e.target.value === 'ollama') {
|
||||||
embeddingModel = '';
|
embeddingModel = '';
|
||||||
|
} else if (e.target.value === 'openai') {
|
||||||
|
embeddingModel = 'text-embedding-3-small';
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
|
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
|
||||||
<option value="ollama">{$i18n.t('Ollama')}</option>
|
<option value="ollama">{$i18n.t('Ollama')}</option>
|
||||||
|
<option value="openai">{$i18n.t('OpenAI')}</option>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{#if embeddingEngine === 'openai'}
|
||||||
|
<div class="mt-1 flex gap-2">
|
||||||
|
<input
|
||||||
|
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
|
||||||
|
placeholder={$i18n.t('API Base URL')}
|
||||||
|
bind:value={openAIUrl}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
|
||||||
|
<input
|
||||||
|
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
|
||||||
|
placeholder={$i18n.t('API Key')}
|
||||||
|
bind:value={openAIKey}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="space-y-2">
|
<div class="space-y-2">
|
||||||
|
|
Loading…
Reference in a new issue