forked from open-webui/open-webui
main #3
2 changed files with 79 additions and 33 deletions
|
@ -111,39 +111,6 @@ class StoreWebForm(CollectionNameForm):
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split ``data`` into chunks and store them in a new Chroma collection.

    Args:
        data: Documents handed to ``RecursiveCharacterTextSplitter.split_documents``.
        collection_name: Name of the collection to create.
        overwrite: When true, any existing collection with the same name is
            deleted before the new one is created.

    Returns:
        True when the chunks were added, or when the failure was a
        ``UniqueConstraintError`` (matched by exception class name).
        False for any other exception.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE,
        chunk_overlap=app.state.CHUNK_OVERLAP,
    )
    chunks = splitter.split_documents(data)

    chunk_texts = [chunk.page_content for chunk in chunks]
    chunk_metadatas = [chunk.metadata for chunk in chunks]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if collection_name != existing.name:
                    continue
                print(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        target = CHROMA_CLIENT.create_collection(
            name=collection_name,
            embedding_function=app.state.sentence_transformer_ef,
        )

        # One freshly generated id per chunk.
        chunk_ids = [str(uuid.uuid1()) for _ in chunk_texts]
        target.add(documents=chunk_texts, metadatas=chunk_metadatas, ids=chunk_ids)
        return True
    except Exception as err:
        print(err)
        # Only the exception's class name is inspected here; a
        # UniqueConstraintError is reported as success, anything else as failure.
        return err.__class__.__name__ == "UniqueConstraintError"
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get_status():
|
async def get_status():
|
||||||
return {
|
return {
|
||||||
|
@ -325,6 +292,56 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split ``data`` into chunks and store them via ``store_docs_in_vector_db``.

    Args:
        data: Documents handed to ``RecursiveCharacterTextSplitter.split_documents``.
        collection_name: Name of the target collection.
        overwrite: Forwarded unchanged to ``store_docs_in_vector_db``.

    Returns:
        Whatever ``store_docs_in_vector_db`` returns for the resulting chunks.
    """
    # Chunking parameters come from application state; add_start_index is
    # requested from the splitter (presumably recorded in chunk metadata —
    # confirm against the splitter documentation).
    splitter_options = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    chunks = splitter.split_documents(data)
    return store_docs_in_vector_db(chunks, collection_name, overwrite)
|
||||||
|
|
||||||
|
|
||||||
|
def store_text_in_vector_db(
    text, name, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw text into chunks and store them via ``store_docs_in_vector_db``.

    Args:
        text: The text to split; passed to the splitter as a single-item list.
        name: Supplied to the splitter as ``{"name": name}`` metadata for
            that text.
        collection_name: Name of the target collection.
        overwrite: Forwarded unchanged to ``store_docs_in_vector_db``.

    Returns:
        Whatever ``store_docs_in_vector_db`` returns for the resulting chunks.
    """
    splitter_options = {
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
        "add_start_index": True,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)

    # One source text, one metadata dict.
    source_texts = [text]
    source_metadatas = [{"name": name}]
    chunks = splitter.create_documents(source_texts, metadatas=source_metadatas)

    return store_docs_in_vector_db(chunks, collection_name, overwrite)
|
||||||
|
|
||||||
|
|
||||||
|
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Store already-split documents in a new Chroma collection.

    Args:
        docs: Iterable of objects exposing ``page_content`` and ``metadata``.
            It is iterated twice, so it is presumably a list — confirm
            against callers.
        collection_name: Name of the collection to create.
        overwrite: When true, any existing collection with the same name is
            deleted before the new one is created.

    Returns:
        True when the documents were added, or when the failure was a
        ``UniqueConstraintError`` (matched by exception class name).
        False for any other exception.
    """
    chunk_texts = [chunk.page_content for chunk in docs]
    chunk_metadatas = [chunk.metadata for chunk in docs]

    try:
        if overwrite:
            for existing in CHROMA_CLIENT.list_collections():
                if collection_name != existing.name:
                    continue
                print(f"deleting existing collection {collection_name}")
                CHROMA_CLIENT.delete_collection(name=collection_name)

        target = CHROMA_CLIENT.create_collection(
            name=collection_name,
            embedding_function=app.state.sentence_transformer_ef,
        )

        # One freshly generated id per chunk.
        chunk_ids = [str(uuid.uuid1()) for _ in chunk_texts]
        target.add(documents=chunk_texts, metadatas=chunk_metadatas, ids=chunk_ids)
        return True
    except Exception as err:
        print(err)
        # Only the exception's class name is inspected here; a
        # UniqueConstraintError is reported as success, anything else as failure.
        return err.__class__.__name__ == "UniqueConstraintError"
|
||||||
|
|
||||||
|
|
||||||
def get_loader(filename: str, file_content_type: str, file_path: str):
|
def get_loader(filename: str, file_content_type: str, file_path: str):
|
||||||
file_ext = filename.split(".")[-1].lower()
|
file_ext = filename.split(".")[-1].lower()
|
||||||
known_type = True
|
known_type = True
|
||||||
|
@ -460,6 +477,33 @@ def store_doc(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextRAGForm(BaseModel):
    """Request body accepted by the ``POST /text`` endpoint."""

    # Label for the text; the endpoint forwards it to store_text_in_vector_db.
    name: str
    # The raw text to be stored.
    content: str
    # Target collection; when omitted (None) the endpoint derives a name
    # from the content instead.
    collection_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Store the posted text in the vector database.

    Args:
        form_data: Carries the text (``content``), its label (``name``) and
            an optional ``collection_name``.
        user: Resolved through ``Depends(get_current_user)``; not otherwise
            used in this handler.

    Returns:
        ``{"status": True, "collection_name": <name>}`` when storing succeeds.

    Raises:
        HTTPException: 500 with the default error message when
            ``store_text_in_vector_db`` reports failure.
    """
    collection_name = form_data.collection_name
    # Identity comparison is the correct test for "no collection supplied"
    # (PEP 8); fall back to a name derived from the content's SHA-256.
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)

    result = store_text_in_vector_db(form_data.content, form_data.name, collection_name)

    if not result:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {"status": True, "collection_name": collection_name}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/scan")
|
@app.get("/scan")
|
||||||
def scan_docs_dir(user=Depends(get_admin_user)):
|
def scan_docs_dir(user=Depends(get_admin_user)):
|
||||||
for path in Path(DOCS_DIR).rglob("./**/*"):
|
for path in Path(DOCS_DIR).rglob("./**/*"):
|
||||||
|
|
|
@ -137,6 +137,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
||||||
k=k,
|
k=k,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
)
|
)
|
||||||
|
elif doc["type"] == "text":
|
||||||
|
context = doc["content"]
|
||||||
else:
|
else:
|
||||||
context = query_doc(
|
context = query_doc(
|
||||||
collection_name=doc["collection_name"],
|
collection_name=doc["collection_name"],
|
||||||
|
|
Loading…
Reference in a new issue