forked from open-webui/open-webui
		
	feat: chromadb vector store api
This commit is contained in:
		
							parent
							
								
									b2c9f6dff8
								
							
						
					
					
						commit
						784b369cc9
					
				
					 4 changed files with 119 additions and 11 deletions
				
			
		
							
								
								
									
										1
									
								
								backend/.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								backend/.gitignore
									
										
									
									
										vendored
									
									
								
							|  | @ -6,3 +6,4 @@ uploads | |||
| *.db | ||||
| _test | ||||
| Pipfile | ||||
| data/* | ||||
|  | @ -1,9 +1,25 @@ | |||
| from fastapi import FastAPI, Request, Depends, HTTPException | ||||
| from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| 
 | ||||
| from apps.web.routers import auths, users, chats, modelfiles, utils | ||||
| from config import WEBUI_VERSION, WEBUI_AUTH | ||||
| from chromadb.utils import embedding_functions | ||||
| 
 | ||||
| from langchain.document_loaders import WebBaseLoader, TextLoader | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | ||||
| from langchain_community.vectorstores import Chroma | ||||
| from langchain.chains import RetrievalQA | ||||
| 
 | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import Optional | ||||
| 
 | ||||
| import uuid | ||||
| 
 | ||||
| from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|     model_name=EMBED_MODEL | ||||
| ) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
|  | @ -18,6 +34,84 @@ app.add_middleware( | |||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class StoreWebForm(BaseModel): | ||||
|     url: str | ||||
|     collection_name: Optional[str] = "test" | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name): | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     collection = CHROMA_CLIENT.create_collection( | ||||
|         name=collection_name, embedding_function=EMBEDDING_FUNC | ||||
|     ) | ||||
| 
 | ||||
|     collection.add( | ||||
|         documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def get_status(): | ||||
|     return {"status": True} | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/query/{collection_name}") | ||||
| def query_collection(collection_name: str, query: str, k: Optional[int] = 4): | ||||
|     collection = CHROMA_CLIENT.get_collection( | ||||
|         name=collection_name, | ||||
|     ) | ||||
|     result = collection.query(query_texts=[query], n_results=k) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/web") | ||||
| def store_web(form_data: StoreWebForm): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
|     try: | ||||
|         loader = WebBaseLoader(form_data.url) | ||||
|         data = loader.load() | ||||
|         store_data_in_vector_db(data, form_data.collection_name) | ||||
|         return {"status": True} | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/doc") | ||||
| def store_doc(file: UploadFile = File(...)): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
| 
 | ||||
|     try: | ||||
|         print(file) | ||||
|         file.filename = f"{uuid.uuid4()}-{file.filename}" | ||||
|         contents = file.file.read() | ||||
|         with open(f"./data/{file.filename}", "wb") as f: | ||||
|             f.write(contents) | ||||
|             f.close() | ||||
| 
 | ||||
|         # loader = WebBaseLoader(form_data.url) | ||||
|         # data = loader.load() | ||||
|         # store_data_in_vector_db(data, form_data.collection_name) | ||||
|         return {"status": True} | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def reset_vector_db(): | ||||
|     CHROMA_CLIENT.reset() | ||||
|     return {"status": True} | ||||
|  |  | |||
|  | @ -1,11 +1,11 @@ | |||
| from dotenv import load_dotenv, find_dotenv | ||||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
| import os | ||||
| import chromadb | ||||
| 
 | ||||
| from secrets import token_bytes | ||||
| from base64 import b64encode | ||||
| 
 | ||||
| import os | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| load_dotenv(find_dotenv("../.env")) | ||||
| 
 | ||||
|  | @ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev") | |||
| # OLLAMA_API_BASE_URL | ||||
| #################################### | ||||
| 
 | ||||
| OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL", | ||||
|                                      "http://localhost:11434/api") | ||||
| OLLAMA_API_BASE_URL = os.environ.get( | ||||
|     "OLLAMA_API_BASE_URL", "http://localhost:11434/api" | ||||
| ) | ||||
| 
 | ||||
| if ENV == "prod": | ||||
|     if OLLAMA_API_BASE_URL == "/ollama/api": | ||||
|  | @ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") | |||
| 
 | ||||
| if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": | ||||
|     raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) | ||||
| 
 | ||||
| #################################### | ||||
| # RAG | ||||
| #################################### | ||||
| 
 | ||||
| CHROMA_DATA_PATH = "./data/vector_db" | ||||
| EMBED_MODEL = "all-MiniLM-L6-v2" | ||||
| CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH) | ||||
| CHUNK_SIZE = 1500 | ||||
| CHUNK_OVERLAP = 100 | ||||
|  |  | |||
|  | @ -6,7 +6,6 @@ class MESSAGES(str, Enum): | |||
| 
 | ||||
| 
 | ||||
| class ERROR_MESSAGES(str, Enum): | ||||
| 
 | ||||
|     def __str__(self) -> str: | ||||
|         return super().__str__() | ||||
| 
 | ||||
|  | @ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum): | |||
|     UNAUTHORIZED = "401 Unauthorized" | ||||
|     ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." | ||||
|     ACTION_PROHIBITED = ( | ||||
|         "The requested action has been restricted as a security measure.") | ||||
|         "The requested action has been restricted as a security measure." | ||||
|     ) | ||||
| 
 | ||||
|     FILE_NOT_SENT = "FILE_NOT_SENT" | ||||
|     NOT_FOUND = "We could not find what you're looking for :/" | ||||
|     USER_NOT_FOUND = "We could not find what you're looking for :/" | ||||
|     API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek