refac: naming convention

This commit is contained in:
Timothy J. Baek 2024-04-27 15:02:57 -04:00
parent 99a43cc998
commit 9be56d68e0
2 changed files with 6 additions and 6 deletions

View file

@ -48,7 +48,7 @@ from apps.web.models.documents import (
from apps.rag.utils import ( from apps.rag.utils import (
get_model_path, get_model_path,
query_embeddings_doc, query_embeddings_doc,
query_embeddings_function, get_embeddings_function,
query_embeddings_collection, query_embeddings_collection,
) )
@ -367,7 +367,7 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
embeddings_function = query_embeddings_function( embeddings_function = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
@ -410,7 +410,7 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
embeddings_function = query_embeddings_function( embeddings_function = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
@ -508,7 +508,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = query_embeddings_function( embedding_func = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,

View file

@ -164,7 +164,7 @@ def rag_template(template: str, context: str, query: str):
return template return template
def query_embeddings_function( def get_embeddings_function(
embedding_engine, embedding_engine,
embedding_model, embedding_model,
embedding_function, embedding_function,
@ -243,7 +243,7 @@ def rag_messages(
content_type = None content_type = None
query = "" query = ""
embeddings_function = query_embeddings_function( embeddings_function = get_embeddings_function(
embedding_engine, embedding_engine,
embedding_model, embedding_model,
embedding_function, embedding_function,