From d5aa9e871045899437b04cfdc36c482adde032a0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 31 Dec 2023 23:35:17 -0800 Subject: [PATCH 01/19] feat: requirements for RAG --- backend/requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/requirements.txt b/backend/requirements.txt index 2644d559..8e3fb3ed 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,5 +16,7 @@ aiohttp peewee bcrypt +chromadb + PyJWT pyjwt[crypto] From b2c9f6dff8cd1b47ffe1ac64022fe149a28164fd Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 6 Jan 2024 22:07:20 -0800 Subject: [PATCH 02/19] feat: rag api endpoint --- backend/apps/rag/main.py | 23 +++++++++++++++++++++++ backend/main.py | 11 +++++++---- backend/requirements.txt | 3 +++ 3 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 backend/apps/rag/main.py diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py new file mode 100644 index 00000000..6870792d --- /dev/null +++ b/backend/apps/rag/main.py @@ -0,0 +1,23 @@ +from fastapi import FastAPI, Request, Depends, HTTPException +from fastapi.middleware.cors import CORSMiddleware + +from apps.web.routers import auths, users, chats, modelfiles, utils +from config import WEBUI_VERSION, WEBUI_AUTH + + +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/") +async def get_status(): + return {"status": True} diff --git a/backend/main.py b/backend/main.py index 0315e5f5..b682aad2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,16 +5,18 @@ from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException + from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app from apps.web.main import app as webui_app +from apps.rag.main import app as rag_app + import time class SPAStaticFiles(StaticFiles): - async def get_response(self, path: str, scope): try: return await super().get_response(path, scope) @@ -49,9 +51,10 @@ async def check_url(request: Request, call_next): app.mount("/api/v1", webui_app) + app.mount("/ollama/api", ollama_app) app.mount("/openai/api", openai_app) +app.mount("/rag/api/v1", rag_app) -app.mount("/", - SPAStaticFiles(directory="../build", html=True), - name="spa-static-files") + +app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") diff --git a/backend/requirements.txt b/backend/requirements.txt index ffdab8fe..7e8c24e6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,7 +16,10 @@ aiohttp peewee bcrypt + +langchain chromadb +sentence_transformers PyJWT pyjwt[crypto] From 784b369cc9279c8249da968d2f8dcefe7951bf9a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 6 Jan 2024 22:59:22 -0800 Subject: [PATCH 03/19] feat: chromadb vector store api --- backend/.gitignore | 3 +- backend/apps/rag/main.py | 100 +++++++++++++++++++++++++++++++++++++-- backend/config.py | 21 ++++++-- backend/constants.py | 6 ++- 4 files changed, 119 insertions(+), 11 deletions(-) diff --git a/backend/.gitignore b/backend/.gitignore index da641cf7..62a3a06a 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -5,4 +5,5 @@ uploads .ipynb_checkpoints *.db _test -Pipfile \ No newline at end of file +Pipfile +data/* \ No newline at end of file diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 6870792d..7dae9bc2 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,9 +1,25 @@ -from fastapi import FastAPI, Request, Depends, HTTPException +from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import auths, users, chats, modelfiles, utils -from config import WEBUI_VERSION, WEBUI_AUTH +from chromadb.utils import embedding_functions +from langchain.document_loaders import WebBaseLoader, TextLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.vectorstores import Chroma +from langchain.chains import RetrievalQA + + +from pydantic import BaseModel +from typing import Optional + +import uuid + +from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from constants import ERROR_MESSAGES + +EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=EMBED_MODEL +) app = FastAPI() @@ -18,6 +34,84 @@ app.add_middleware( ) +class StoreWebForm(BaseModel): + url: str + collection_name: Optional[str] = "test" + + +def store_data_in_vector_db(data, collection_name): + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP + ) + docs = text_splitter.split_documents(data) + + texts = [doc.page_content for doc in docs] + metadatas = [doc.metadata for doc in docs] + + collection = CHROMA_CLIENT.create_collection( + name=collection_name, embedding_function=EMBEDDING_FUNC + ) + + collection.add( + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) + + @app.get("/") async def get_status(): return {"status": True} + + +@app.get("/query/{collection_name}") +def query_collection(collection_name: str, query: str, k: Optional[int] = 4): + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query(query_texts=[query], n_results=k) + + return result + + +@app.post("/web") +def store_web(form_data: StoreWebForm): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + try: + loader = WebBaseLoader(form_data.url) + data = loader.load() + store_data_in_vector_db(data, form_data.collection_name) + return {"status": True} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@app.post("/doc") +def store_doc(file: UploadFile = File(...)): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + + try: + print(file) + file.filename = f"{uuid.uuid4()}-{file.filename}" + contents = file.file.read() + with open(f"./data/{file.filename}", "wb") as f: + f.write(contents) + f.close() + + # loader = WebBaseLoader(form_data.url) + # data = loader.load() + # store_data_in_vector_db(data, form_data.collection_name) + return {"status": True} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +def reset_vector_db(): + CHROMA_CLIENT.reset() + return {"status": True} diff --git a/backend/config.py b/backend/config.py index 4c518d13..df57c829 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,11 +1,11 @@ from dotenv import load_dotenv, find_dotenv - -from constants import ERROR_MESSAGES +import os +import chromadb from secrets import token_bytes from base64 import b64encode -import os +from constants import ERROR_MESSAGES load_dotenv(find_dotenv("../.env")) @@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev") # OLLAMA_API_BASE_URL #################################### -OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL", - "http://localhost:11434/api") +OLLAMA_API_BASE_URL = os.environ.get( + "OLLAMA_API_BASE_URL", "http://localhost:11434/api" +) if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": @@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +#################################### +# RAG +#################################### + +CHROMA_DATA_PATH = "./data/vector_db" +EMBED_MODEL = "all-MiniLM-L6-v2" +CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH) +CHUNK_SIZE = 1500 +CHUNK_OVERLAP = 100 diff --git a/backend/constants.py b/backend/constants.py index c3fd0dc5..9893744c 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -6,7 +6,6 @@ class MESSAGES(str, Enum): class ERROR_MESSAGES(str, Enum): - def __str__(self) -> str: return super().__str__() @@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum): UNAUTHORIZED = "401 Unauthorized" ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACTION_PROHIBITED = ( - "The requested action has been restricted as a security measure.") + "The requested action has been restricted as a security measure." + ) + + FILE_NOT_SENT = "FILE_NOT_SENT" NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." From cd86c369537e949cd01480e93729187ed73688de Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 6 Jan 2024 23:40:51 -0800 Subject: [PATCH 04/19] feat: pdf data load --- backend/apps/rag/main.py | 48 +++++++++++++++++++++++++++++++--------- backend/constants.py | 2 ++ backend/requirements.txt | 1 + 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 7dae9bc2..67b118cd 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,9 +1,18 @@ -from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File +from fastapi import ( + FastAPI, + Request, + Depends, + HTTPException, + status, + UploadFile, + File, + Form, +) from fastapi.middleware.cors import CORSMiddleware from chromadb.utils import embedding_functions -from langchain.document_loaders import WebBaseLoader, TextLoader +from langchain.document_loaders import WebBaseLoader, TextLoader, PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import Chroma from langchain.chains import RetrievalQA @@ -34,11 +43,14 @@ app.add_middleware( ) -class StoreWebForm(BaseModel): - url: str +class CollectionNameForm(BaseModel): collection_name: Optional[str] = "test" +class StoreWebForm(CollectionNameForm): + url: str + + def store_data_in_vector_db(data, collection_name): text_splitter = RecursiveCharacterTextSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP @@ -89,20 +101,34 @@ def store_web(form_data: StoreWebForm): @app.post("/doc") -def store_doc(file: UploadFile = File(...)): +def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + file.filename = f"{uuid.uuid4()}-{file.filename}" + print(dir(file)) + print(file.content_type) + + if file.content_type not in ["application/pdf", "text/plain"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, + ) + try: - print(file) - file.filename = f"{uuid.uuid4()}-{file.filename}" + filename = file.filename + file_path = f"./data/{filename}" contents = file.file.read() - with open(f"./data/{file.filename}", "wb") as f: + with open(file_path, "wb") as f: f.write(contents) f.close() - # loader = WebBaseLoader(form_data.url) - # data = loader.load() - # store_data_in_vector_db(data, form_data.collection_name) + if file.content_type == "application/pdf": + loader = PyPDFLoader(file_path) + elif file.content_type == "text/plain": + loader = TextLoader(file_path) + + data = loader.load() + store_data_in_vector_db(data, collection_name) return {"status": True} except Exception as e: print(e) diff --git a/backend/constants.py b/backend/constants.py index 9893744c..0f7a46a0 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -33,6 +33,8 @@ class ERROR_MESSAGES(str, Enum): ) FILE_NOT_SENT = "FILE_NOT_SENT" + FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format (e.g., JPG, PNG, PDF, TXT) and try again." + NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." diff --git a/backend/requirements.txt b/backend/requirements.txt index 7e8c24e6..8c569536 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -20,6 +20,7 @@ bcrypt langchain chromadb sentence_transformers +pypdf PyJWT pyjwt[crypto] From 3229ec116c6847e82c726ad9b765c07bfe185f8e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 6 Jan 2024 23:52:22 -0800 Subject: [PATCH 05/19] feat: rag apis added to frontend --- src/lib/apis/rag/index.ts | 108 ++++++++++++++++++++++++++++++++++++++ src/lib/constants.ts | 1 + 2 files changed, 109 insertions(+) create mode 100644 src/lib/apis/rag/index.ts diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts new file mode 100644 index 00000000..44ac0430 --- /dev/null +++ b/src/lib/apis/rag/index.ts @@ -0,0 +1,108 @@ +import { RAG_API_BASE_URL } from '$lib/constants'; + +export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { + const data = new FormData(); + data.append('file', file); + data.append('collection_name', collection_name); + + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: data + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const uploadWebToVectorDB = async (token: string, collection_name: string, url: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/web`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url: url, + collection_name: collection_name + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const queryVectorDB = async ( + token: string, + collection_name: string, + query: string, + k: number +) => { + let error = null; + const searchParams = new URLSearchParams(); + + searchParams.set('query', query); + if (k) { + searchParams.set('k', k.toString()); + } + + const res = await fetch( + `${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/constants.ts b/src/lib/constants.ts index 27744197..a43104ad 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -5,6 +5,7 @@ export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``; export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`; export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`; +export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`; export const WEB_UI_VERSION = 'v1.0.0-alpha-static'; From fef4725d569efd753bc0b8d216f124bc64c42f3d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 7 Jan 2024 00:57:10 -0800 Subject: [PATCH 06/19] feat: frontend file upload support --- backend/apps/rag/main.py | 4 +- src/lib/apis/rag/index.ts | 3 - src/lib/components/chat/MessageInput.svelte | 77 +++++++-- src/lib/utils/index.ts | 35 ++++ src/lib/utils/rag/index.ts | 20 +++ src/routes/(app)/+page.svelte | 176 ++++++++++++-------- 6 files changed, 223 insertions(+), 92 deletions(-) create mode 100644 src/lib/utils/rag/index.ts diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 67b118cd..0e7f9b07 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -91,7 +91,7 @@ def store_web(form_data: StoreWebForm): loader = WebBaseLoader(form_data.url) data = loader.load() store_data_in_vector_db(data, form_data.collection_name) - return {"status": True} + return {"status": True, "collection_name": form_data.collection_name} except Exception as e: print(e) raise HTTPException( @@ -129,7 +129,7 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): data = loader.load() store_data_in_vector_db(data, collection_name) - return {"status": True} + return {"status": True, "collection_name": collection_name} except Exception as e: print(e) raise HTTPException( diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 44ac0430..bafd0360 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -11,7 +11,6 @@ export const uploadDocToVectorDB = async (token: string, collection_name: string method: 'POST', headers: { Accept: 'application/json', - 'Content-Type': 'application/json', authorization: `Bearer ${token}` }, body: data @@ -85,7 +84,6 @@ export const queryVectorDB = async ( method: 'GET', headers: { Accept: 'application/json', - 'Content-Type': 'application/json', authorization: `Bearer ${token}` } } @@ -96,7 +94,6 @@ export const queryVectorDB = async ( }) .catch((err) => { error = err.detail; - console.log(err); return null; }); diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 1468310d..4cce1fcb 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -2,10 +2,11 @@ import toast from 'svelte-french-toast'; import { onMount, tick } from 'svelte'; import { settings } from '$lib/stores'; - import { findWordIndices } from '$lib/utils'; + import { calculateSHA256, findWordIndices } from '$lib/utils'; import Prompts from './MessageInput/PromptCommands.svelte'; import Suggestions from './MessageInput/Suggestions.svelte'; + import { uploadDocToVectorDB } from '$lib/apis/rag'; export let submitPrompt: Function; export let stopResponse: Function; @@ -98,7 +99,7 @@ dragged = true; }); - dropZone.addEventListener('drop', (e) => { + dropZone.addEventListener('drop', async (e) => { e.preventDefault(); console.log(e); @@ -115,14 +116,30 @@ ]; }; - if ( - e.dataTransfer?.files && - e.dataTransfer?.files.length > 0 && - ['image/gif', 'image/jpeg', 'image/png'].includes(e.dataTransfer?.files[0]['type']) - ) { - reader.readAsDataURL(e.dataTransfer?.files[0]); + if (e.dataTransfer?.files && e.dataTransfer?.files.length > 0) { + const file = e.dataTransfer?.files[0]; + if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) { + reader.readAsDataURL(file); + } else if (['application/pdf', 'text/plain'].includes(file['type'])) { + console.log(file); + const hash = await calculateSHA256(file); + // const res = uploadDocToVectorDB(localStorage.token, hash,file); + + if (true) { + files = [ + ...files, + { + type: 'doc', + name: file.name, + collection_name: hash + } + ]; + } + } else { + toast.error(`Unsupported File Type '${file['type']}'.`); + } } else { - toast.error(`Unsupported File Type '${e.dataTransfer?.files[0]['type']}'.`); + toast.error(`File not found.`); } } @@ -145,11 +162,11 @@
-
🏞️
-
Add Images
+
🗂️
+
Add Files
- Drop any images here to add to the conversation + Drop any files/images here to add to the conversation
@@ -237,10 +254,42 @@ }} > {#if files.length > 0} -
+
{#each files as file, fileIdx}
- input + {#if file.type === 'image'} + input + {:else if file.type === 'doc'} +
+
+ + + + +
+ +
+
+ {file.name} +
+ +
Document
+
+
+ {/if}