feat: chromadb vector store api

This commit is contained in:
Timothy J. Baek 2024-01-06 22:59:22 -08:00
parent b2c9f6dff8
commit 784b369cc9
4 changed files with 119 additions and 11 deletions

1
backend/.gitignore vendored
View file

@ -6,3 +6,4 @@ uploads
*.db *.db
_test _test
Pipfile Pipfile
data/*

View file

@ -1,9 +1,25 @@
from fastapi import FastAPI, Request, Depends, HTTPException from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from chromadb.utils import embedding_functions
from config import WEBUI_VERSION, WEBUI_AUTH
from langchain.document_loaders import WebBaseLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from pydantic import BaseModel
from typing import Optional
import uuid
from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES
# Embedding function built from the configured EMBED_MODEL; it is handed to
# Chroma when a collection is created.
EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBED_MODEL)
app = FastAPI() app = FastAPI()
@ -18,6 +34,84 @@ app.add_middleware(
) )
class StoreWebForm(BaseModel):
    # Request body accepted by the POST /web endpoint.

    # Address of the page to load.
    url: str

    # Collection that receives the page's chunks; "test" when omitted.
    collection_name: Optional[str] = "test"
def store_data_in_vector_db(data, collection_name):
    """Split documents into chunks and add them to a Chroma collection.

    Args:
        data: Documents accepted by
            ``RecursiveCharacterTextSplitter.split_documents`` (the web
            endpoint passes the result of a loader's ``load()``).
        collection_name: Name of the Chroma collection that receives the
            chunks. It is created on first use and reused afterwards.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    docs = text_splitter.split_documents(data)

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    # create_collection() fails when the name already exists, so every second
    # upload into the same collection (including the default "test") errored
    # out. get_or_create_collection() reuses the existing collection instead.
    collection = CHROMA_CLIENT.get_or_create_collection(
        name=collection_name, embedding_function=EMBEDDING_FUNC
    )
    collection.add(
        documents=texts,
        metadatas=metadatas,
        # One fresh id per chunk.
        ids=[str(uuid.uuid1()) for _ in texts],
    )
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return {"status": True} return {"status": True}
@app.get("/query/{collection_name}")
def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
    """Return the ``k`` nearest stored chunks for ``query``.

    Args:
        collection_name: Name of the Chroma collection to search.
        query: Text to search for.
        k: Number of results requested from Chroma (default 4).

    Returns:
        The result object produced by ``collection.query``.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT`` of the underlying
            error when the lookup or the query fails, matching the other
            endpoints in this module.
    """
    try:
        # Use the same embedding function the collection was created with so
        # the query text is embedded consistently with the stored documents.
        collection = CHROMA_CLIENT.get_collection(
            name=collection_name,
            embedding_function=EMBEDDING_FUNC,
        )
        return collection.query(query_texts=[query], n_results=k)
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
@app.post("/web")
def store_web(form_data: StoreWebForm):
    """Load the page at ``form_data.url`` and store it in the vector database.

    Returns ``{"status": True}`` on success. Any failure is printed and
    reported as an HTTP 400 whose detail is ``ERROR_MESSAGES.DEFAULT`` of
    the error.
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        documents = WebBaseLoader(form_data.url).load()
        store_data_in_vector_db(documents, form_data.collection_name)
    except Exception as err:
        print(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(err),
        )
    else:
        return {"status": True}
@app.post("/doc")
def store_doc(file: UploadFile = File(...)):
    """Save an uploaded file under ``./data`` with a unique name.

    The stored name is ``<uuid4>-<original file name>``. The file is only
    written to disk; it is not yet loaded into the vector database.

    Args:
        file: The uploaded file.

    Returns:
        ``{"status": True}`` once the file has been written.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT`` of the underlying
            error when saving fails.
    """
    import os
    import shutil

    try:
        # The client controls the file name: keep only its last path
        # component so the target always stays directly inside ./data.
        original_name = os.path.basename((file.filename or "").replace("\\", "/"))
        file.filename = f"{uuid.uuid4()}-{original_name}"

        # Stream the upload to disk instead of reading it fully into memory;
        # the ``with`` block closes the target file on exit.
        with open(os.path.join("./data", file.filename), "wb") as f:
            shutil.copyfileobj(file.file, f)

        # TODO: load the saved file and pass it to store_data_in_vector_db.
        return {"status": True}
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
def reset_vector_db():
    """Call ``reset()`` on the shared Chroma client and report success."""
    CHROMA_CLIENT.reset()
    outcome = {"status": True}
    return outcome

View file

@ -1,11 +1,11 @@
from dotenv import load_dotenv, find_dotenv from dotenv import load_dotenv, find_dotenv
import os
from constants import ERROR_MESSAGES import chromadb
from secrets import token_bytes from secrets import token_bytes
from base64 import b64encode from base64 import b64encode
import os from constants import ERROR_MESSAGES
load_dotenv(find_dotenv("../.env")) load_dotenv(find_dotenv("../.env"))
@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev")
# OLLAMA_API_BASE_URL # OLLAMA_API_BASE_URL
#################################### ####################################
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL", OLLAMA_API_BASE_URL = os.environ.get(
"http://localhost:11434/api") "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)
if ENV == "prod": if ENV == "prod":
if OLLAMA_API_BASE_URL == "/ollama/api": if OLLAMA_API_BASE_URL == "/ollama/api":
@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t")
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG
####################################

# Chunking parameters handed to the text splitter.
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100

# Model name handed to the sentence-transformers embedding function.
EMBED_MODEL = "all-MiniLM-L6-v2"

# On-disk location of the Chroma store, and the shared persistent client.
CHROMA_DATA_PATH = "./data/vector_db"
CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)

View file

@ -6,7 +6,6 @@ class MESSAGES(str, Enum):
class ERROR_MESSAGES(str, Enum): class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str: def __str__(self) -> str:
return super().__str__() return super().__str__()
@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum):
UNAUTHORIZED = "401 Unauthorized" UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = ( ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure.") "The requested action has been restricted as a security measure."
)
FILE_NOT_SENT = "FILE_NOT_SENT"
NOT_FOUND = "We could not find what you're looking for :/" NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."