Merge branch 'dev' into embedding-model-fix-and-manual-update

This commit is contained in:
lainedfles 2024-04-08 14:57:54 -06:00 committed by GitHub
commit 506a061387
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 1906 additions and 520 deletions

View file

@ -58,8 +58,8 @@ from config import (
UPLOAD_DIR,
DOCS_DIR,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
DEVICE_TYPE,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
@ -86,7 +86,7 @@ app.state.TOP_K = 4
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
device=DEVICE_TYPE,
)
)
@ -154,7 +154,7 @@ async def update_embedding_model(
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
device=DEVICE_TYPE,
)
)
except Exception as e:
@ -471,25 +471,11 @@ def store_doc(
log.info(f"file.content_type: {file.content_type}")
try:
is_valid_filename = True
unsanitized_filename = file.filename
if re.search(r'[\\/:"\*\?<>|\n\t ]', unsanitized_filename) is not None:
is_valid_filename = False
filename = os.path.basename(unsanitized_filename)
unvalidated_file_path = f"{UPLOAD_DIR}/{unsanitized_filename}"
dereferenced_file_path = str(Path(unvalidated_file_path).resolve(strict=False))
if not dereferenced_file_path.startswith(UPLOAD_DIR):
is_valid_filename = False
file_path = f"{UPLOAD_DIR}/{filename}"
if is_valid_filename:
file_path = dereferenced_file_path
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
filename = file.filename
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
@ -500,7 +486,7 @@ def store_doc(
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(file.filename, file.content_type, file_path)
loader, known_type = get_loader(filename, file.content_type, file_path)
data = loader.load()
try: