forked from open-webui/open-webui
feat: chromadb vector store api
This commit is contained in:
parent
b2c9f6dff8
commit
784b369cc9
4 changed files with 119 additions and 11 deletions
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
|
@ -6,3 +6,4 @@ uploads
|
||||||
*.db
|
*.db
|
||||||
_test
|
_test
|
||||||
Pipfile
|
Pipfile
|
||||||
|
data/*
|
|
@ -1,9 +1,25 @@
|
||||||
from fastapi import FastAPI, Request, Depends, HTTPException
|
from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from apps.web.routers import auths, users, chats, modelfiles, utils
|
from chromadb.utils import embedding_functions
|
||||||
from config import WEBUI_VERSION, WEBUI_AUTH
|
|
||||||
|
|
||||||
|
from langchain.document_loaders import WebBaseLoader, TextLoader
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
from langchain_community.vectorstores import Chroma
|
||||||
|
from langchain.chains import RetrievalQA
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
||||||
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||||
|
model_name=EMBED_MODEL
|
||||||
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -18,6 +34,84 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StoreWebForm(BaseModel):
|
||||||
|
url: str
|
||||||
|
collection_name: Optional[str] = "test"
|
||||||
|
|
||||||
|
|
||||||
|
def store_data_in_vector_db(data, collection_name):
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
|
||||||
|
)
|
||||||
|
docs = text_splitter.split_documents(data)
|
||||||
|
|
||||||
|
texts = [doc.page_content for doc in docs]
|
||||||
|
metadatas = [doc.metadata for doc in docs]
|
||||||
|
|
||||||
|
collection = CHROMA_CLIENT.create_collection(
|
||||||
|
name=collection_name, embedding_function=EMBEDDING_FUNC
|
||||||
|
)
|
||||||
|
|
||||||
|
collection.add(
|
||||||
|
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get_status():
|
async def get_status():
|
||||||
return {"status": True}
|
return {"status": True}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/query/{collection_name}")
|
||||||
|
def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
|
||||||
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
|
name=collection_name,
|
||||||
|
)
|
||||||
|
result = collection.query(query_texts=[query], n_results=k)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/web")
|
||||||
|
def store_web(form_data: StoreWebForm):
|
||||||
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||||
|
try:
|
||||||
|
loader = WebBaseLoader(form_data.url)
|
||||||
|
data = loader.load()
|
||||||
|
store_data_in_vector_db(data, form_data.collection_name)
|
||||||
|
return {"status": True}
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/doc")
|
||||||
|
def store_doc(file: UploadFile = File(...)):
|
||||||
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(file)
|
||||||
|
file.filename = f"{uuid.uuid4()}-{file.filename}"
|
||||||
|
contents = file.file.read()
|
||||||
|
with open(f"./data/{file.filename}", "wb") as f:
|
||||||
|
f.write(contents)
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
# loader = WebBaseLoader(form_data.url)
|
||||||
|
# data = loader.load()
|
||||||
|
# store_data_in_vector_db(data, form_data.collection_name)
|
||||||
|
return {"status": True}
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_vector_db():
|
||||||
|
CHROMA_CLIENT.reset()
|
||||||
|
return {"status": True}
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
from dotenv import load_dotenv, find_dotenv
|
from dotenv import load_dotenv, find_dotenv
|
||||||
|
import os
|
||||||
from constants import ERROR_MESSAGES
|
import chromadb
|
||||||
|
|
||||||
from secrets import token_bytes
|
from secrets import token_bytes
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
|
||||||
import os
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
load_dotenv(find_dotenv("../.env"))
|
load_dotenv(find_dotenv("../.env"))
|
||||||
|
|
||||||
|
@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev")
|
||||||
# OLLAMA_API_BASE_URL
|
# OLLAMA_API_BASE_URL
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
|
OLLAMA_API_BASE_URL = os.environ.get(
|
||||||
"http://localhost:11434/api")
|
"OLLAMA_API_BASE_URL", "http://localhost:11434/api"
|
||||||
|
)
|
||||||
|
|
||||||
if ENV == "prod":
|
if ENV == "prod":
|
||||||
if OLLAMA_API_BASE_URL == "/ollama/api":
|
if OLLAMA_API_BASE_URL == "/ollama/api":
|
||||||
|
@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t")
|
||||||
|
|
||||||
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
|
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
|
||||||
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# RAG
|
||||||
|
####################################
|
||||||
|
|
||||||
|
CHROMA_DATA_PATH = "./data/vector_db"
|
||||||
|
EMBED_MODEL = "all-MiniLM-L6-v2"
|
||||||
|
CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
|
||||||
|
CHUNK_SIZE = 1500
|
||||||
|
CHUNK_OVERLAP = 100
|
||||||
|
|
|
@ -6,7 +6,6 @@ class MESSAGES(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class ERROR_MESSAGES(str, Enum):
|
class ERROR_MESSAGES(str, Enum):
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
|
|
||||||
|
@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
UNAUTHORIZED = "401 Unauthorized"
|
UNAUTHORIZED = "401 Unauthorized"
|
||||||
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
|
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
|
||||||
ACTION_PROHIBITED = (
|
ACTION_PROHIBITED = (
|
||||||
"The requested action has been restricted as a security measure.")
|
"The requested action has been restricted as a security measure."
|
||||||
|
)
|
||||||
|
|
||||||
|
FILE_NOT_SENT = "FILE_NOT_SENT"
|
||||||
NOT_FOUND = "We could not find what you're looking for :/"
|
NOT_FOUND = "We could not find what you're looking for :/"
|
||||||
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
||||||
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
||||||
|
|
Loading…
Reference in a new issue