fix: integration

Timothy J. Baek 2024-04-14 18:47:45 -04:00
parent 9cdb5bf9fe
commit 36ce157907
3 changed files with 28 additions and 7 deletions

backend/apps/ollama/main.py

@@ -658,6 +658,9 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
+    log.info("generate_ollama_embeddings", form_data)
+
     if url_idx == None:
         model = form_data.model
@@ -685,6 +688,8 @@ def generate_ollama_embeddings(
         data = r.json()
 
+        log.info("generate_ollama_embeddings", data)
+
         if "embedding" in data:
             return data["embedding"]
         else:
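
Note on the added log calls: neither message carries a %s placeholder for the extra positional argument, so with the stdlib logging module the argument is silently dropped when it is a dict (as with data) or triggers a "--- Logging error ---" at emit time when it is a model object (as with form_data). A minimal sketch of both safe forms, with an illustrative payload:

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

form_data = {"model": "all-minilm", "prompt": "hello"}  # illustrative payload

# %-style placeholder: logging formats lazily, only when the record is emitted
log.info("generate_ollama_embeddings %s", form_data)

# f-string: formatted eagerly; always renders the value into the message
log.info(f"generate_ollama_embeddings {form_data}")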

backend/apps/rag/main.py

@@ -39,7 +39,7 @@ import uuid
 import json
 
-from apps.ollama.main import generate_ollama_embeddings
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
 
 from apps.web.models.documents import (
     Documents,
@@ -277,7 +277,12 @@ def query_doc_handler(
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
-                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
             )
 
             return query_embeddings_doc(
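
This wrapping is the substance of the fix: generate_ollama_embeddings is annotated with form_data: GenerateEmbeddingsForm and reads attributes such as form_data.model, and because this is a direct function call rather than a FastAPI route, nothing coerces a plain dict into the model. A minimal sketch of the failure and the fix, with an assumed field layout (the real form in apps.ollama.main may carry more fields):

from typing import Optional
from pydantic import BaseModel

class GenerateEmbeddingsForm(BaseModel):  # assumed shape, for illustration
    model: str
    prompt: str
    options: Optional[dict] = None

payload = {"model": "all-minilm", "prompt": "What does the doc say?"}

# payload["model"] works, but attribute access does not:
# payload.model  # AttributeError: 'dict' object has no attribute 'model'

form = GenerateEmbeddingsForm(**payload)  # what the commit now passes
print(form.model, form.prompt)            # attribute access the callee expects
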
@@ -314,7 +319,12 @@ def query_collection_handler(
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
-                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
             )
 
             return query_embeddings_collection(
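
The **{...} unpacking used in both handlers is just keyword-argument spelling: GenerateEmbeddingsForm(model=app.state.RAG_EMBEDDING_MODEL, prompt=form_data.query) would construct the same object.
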
@@ -373,6 +383,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
+        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -390,9 +401,8 @@
     return store_docs_in_vector_db(docs, collection_name, overwrite)
 
 
-async def store_docs_in_vector_db(
-    docs, collection_name, overwrite: bool = False
-) -> bool:
+def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+    log.info("store_docs_in_vector_db", docs, collection_name)
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
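
Dropping async here looks like the "integration" part of the fix rather than a style change: the callers visible in this hunk invoke store_docs_in_vector_db without await, and an un-awaited coroutine object is truthy, so a bool return tested by the caller would always look like success. A minimal sketch of that failure mode:

import asyncio

async def store_docs_async() -> bool:
    return False  # pretend the store failed

result = store_docs_async()   # called without await: a coroutine object
print(bool(result))           # True -- the failure is masked
print(asyncio.run(result))    # False -- visible only when actually awaited
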
@@ -413,13 +423,16 @@ async def store_docs_in_vector_db(
                 metadatas=metadatas,
                 embeddings=[
                     generate_ollama_embeddings(
-                        {"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                        GenerateEmbeddingsForm(
+                            **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                        )
                     )
                     for text in texts
                 ],
             ):
                 collection.add(*batch)
     else:
         collection = CHROMA_CLIENT.create_collection(
             name=collection_name,
             embedding_function=app.state.sentence_transformer_ef,

backend/apps/rag/utils.py

@@ -32,6 +32,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
+        log.info("query_embeddings_doc", query_embeddings)
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
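
For context on the read path: judging by its name and the k parameter, query_embeddings_doc presumably hands the precomputed vector to a Chroma nearest-neighbor query right after the get_collection call shown above. A self-contained sketch of that pattern (collection name, vectors, and k are illustrative):

import chromadb

client = chromadb.Client()  # in-memory client, stands in for CHROMA_CLIENT
collection = client.get_or_create_collection(name="docs")

collection.add(
    ids=["chunk-1"],
    embeddings=[[0.1, 0.2, 0.3]],
    documents=["hello world"],
)

# query_embeddings takes a list of query vectors; n_results caps the hits
result = collection.query(query_embeddings=[[0.1, 0.2, 0.3]], n_results=1)
print(result["documents"])  # [['hello world']]
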
@@ -117,6 +118,8 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
     results = []
 
+    log.info("query_embeddings_collection", query_embeddings)
+
     for collection_name in collection_names:
         try:
             collection = CHROMA_CLIENT.get_collection(name=collection_name)
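
A note on the loop structure above: keeping the try inside the per-collection loop appears intended to isolate failures, so one missing or broken collection does not abort the query and results from the collections that do resolve can still be aggregated. The embeddings are also computed once by the caller and reused across every collection, so the Ollama call happens per query, not per collection.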