diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 2a8b2a49..99aa6959 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -108,7 +108,7 @@ class StoreWebForm(CollectionNameForm): url: str -def store_data_in_vector_db(data, collection_name) -> bool: +def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP ) @@ -118,6 +118,12 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: + if overwrite: + for collection in CHROMA_CLIENT.list_collections(): + if collection_name == collection.name: + print(f"deleting existing collection {collection_name}") + CHROMA_CLIENT.delete_collection(name=collection_name) + collection = CHROMA_CLIENT.create_collection( name=collection_name, embedding_function=app.state.sentence_transformer_ef, @@ -355,7 +361,7 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): if collection_name == "": collection_name = calculate_sha256_string(form_data.url)[:63] - store_data_in_vector_db(data, collection_name) + store_data_in_vector_db(data, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name,