diff --git a/Dockerfile b/Dockerfile index 2e57ee66..aa6c3d55 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN npm ci COPY . . RUN npm run build -FROM python:3.11-slim-buster as base +FROM python:3.11-bookworm as base ENV ENV=prod @@ -28,6 +28,7 @@ WORKDIR /app/backend COPY ./backend/requirements.txt ./requirements.txt RUN pip3 install -r requirements.txt +RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" COPY ./backend . diff --git a/README.md b/README.md index f35d6678..6d3933b3 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ Also check our sibling project, [OllamaHub](https://ollamahub.com/), where you c - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. +- 📚 **RAG Integration (Alpha)**: Immerse yourself in cutting-edge Retrieval Augmented Generation support, revolutionizing your chat experience by seamlessly incorporating document interactions. In its alpha phase, expect occasional issues as we actively refine and enhance this feature to ensure optimal performance and reliability. + - 📜 **Prompt Preset Support**: Instantly access preset prompts using the '/' command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [OllamaHub](https://ollamahub.com/) integration. - 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data. @@ -243,7 +245,6 @@ See [TROUBLESHOOTING.md](/TROUBLESHOOTING.md) for information on how to troubles Here are some exciting tasks on our roadmap: -- 📚 **RAG Integration**: Experience first-class retrieval augmented generation support, enabling chat with your documents. - 🌐 **Web Browsing Capability**: Experience the convenience of seamlessly integrating web content directly into your chat. Easily browse and share information without leaving the conversation. - 🔄 **Function Calling**: Empower your interactions by running code directly within the chat. Execute functions and commands effortlessly, enhancing the functionality of your conversations. - ⚙️ **Custom Python Backend Actions**: Empower your Ollama Web UI by creating or downloading custom Python backend actions. Unleash the full potential of your web interface with tailored actions that suit your specific needs, enhancing functionality and versatility. diff --git a/backend/.gitignore b/backend/.gitignore index da641cf7..62a3a06a 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -5,4 +5,5 @@ uploads .ipynb_checkpoints *.db _test -Pipfile \ No newline at end of file +Pipfile +data/* \ No newline at end of file diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py new file mode 100644 index 00000000..a59aac6c --- /dev/null +++ b/backend/apps/rag/main.py @@ -0,0 +1,207 @@ +from fastapi import ( + FastAPI, + Request, + Depends, + HTTPException, + status, + UploadFile, + File, + Form, +) +from fastapi.middleware.cors import CORSMiddleware +import os, shutil + +from chromadb.utils import embedding_functions + +from langchain_community.document_loaders import WebBaseLoader, TextLoader, PyPDFLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.vectorstores import Chroma +from langchain.chains import RetrievalQA + + +from pydantic import BaseModel +from typing import Optional + +import uuid + + +from utils.utils import get_current_user +from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from constants import ERROR_MESSAGES + +EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=EMBED_MODEL +) + +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class CollectionNameForm(BaseModel): + collection_name: Optional[str] = "test" + + +class StoreWebForm(CollectionNameForm): + url: str + + +def store_data_in_vector_db(data, collection_name) -> bool: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP + ) + docs = text_splitter.split_documents(data) + + texts = [doc.page_content for doc in docs] + metadatas = [doc.metadata for doc in docs] + + try: + collection = CHROMA_CLIENT.create_collection( + name=collection_name, embedding_function=EMBEDDING_FUNC + ) + + collection.add( + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) + return True + except Exception as e: + print(e) + if e.__class__.__name__ == "UniqueConstraintError": + return True + + return False + + +@app.get("/") +async def get_status(): + return {"status": True} + + +@app.get("/query/{collection_name}") +def query_collection( + collection_name: str, + query: str, + k: Optional[int] = 4, + user=Depends(get_current_user), +): + try: + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query(query_texts=[query], n_results=k) + + return result + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@app.post("/web") +def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + try: + loader = WebBaseLoader(form_data.url) + data = loader.load() + store_data_in_vector_db(data, form_data.collection_name) + return {"status": True, "collection_name": form_data.collection_name} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@app.post("/doc") +def store_doc( + collection_name: str = Form(...), + file: UploadFile = File(...), + user=Depends(get_current_user), +): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + file.filename = f"{collection_name}-{file.filename}" + + if file.content_type not in ["application/pdf", "text/plain"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, + ) + + try: + filename = file.filename + file_path = f"{UPLOAD_DIR}/{filename}" + contents = file.file.read() + with open(file_path, "wb") as f: + f.write(contents) + f.close() + + if file.content_type == "application/pdf": + loader = PyPDFLoader(file_path) + elif file.content_type == "text/plain": + loader = TextLoader(file_path) + + data = loader.load() + result = store_data_in_vector_db(data, collection_name) + + if result: + return {"status": True, "collection_name": collection_name} + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@app.get("/reset/db") +def reset_vector_db(user=Depends(get_current_user)): + if user.role == "admin": + CHROMA_CLIENT.reset() + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +@app.get("/reset") +def reset(user=Depends(get_current_user)): + if user.role == "admin": + folder = f"{UPLOAD_DIR}" + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print("Failed to delete %s. Reason: %s" % (file_path, e)) + + try: + CHROMA_CLIENT.reset() + except Exception as e: + print(e) + + return {"status": True} + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) diff --git a/backend/config.py b/backend/config.py index 4c518d13..03718a06 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,14 +1,31 @@ from dotenv import load_dotenv, find_dotenv +import os + + +import chromadb +from chromadb import Settings -from constants import ERROR_MESSAGES from secrets import token_bytes from base64 import b64encode -import os +from constants import ERROR_MESSAGES + + +from pathlib import Path load_dotenv(find_dotenv("../.env")) + +#################################### +# File Upload +#################################### + + +UPLOAD_DIR = "./data/uploads" +Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # ENV (dev,test,prod) #################################### @@ -19,8 +36,9 @@ ENV = os.environ.get("ENV", "dev") # OLLAMA_API_BASE_URL #################################### -OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL", - "http://localhost:11434/api") +OLLAMA_API_BASE_URL = os.environ.get( + "OLLAMA_API_BASE_URL", "http://localhost:11434/api" +) if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": @@ -56,3 +74,15 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + +#################################### +# RAG +#################################### + +CHROMA_DATA_PATH = "./data/vector_db" +EMBED_MODEL = "all-MiniLM-L6-v2" +CHROMA_CLIENT = chromadb.PersistentClient( + path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True) +) +CHUNK_SIZE = 1500 +CHUNK_OVERLAP = 100 diff --git a/backend/constants.py b/backend/constants.py index c3fd0dc5..0f7a46a0 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -6,7 +6,6 @@ class MESSAGES(str, Enum): class ERROR_MESSAGES(str, Enum): - def __str__(self) -> str: return super().__str__() @@ -30,7 +29,12 @@ class ERROR_MESSAGES(str, Enum): UNAUTHORIZED = "401 Unauthorized" ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACTION_PROHIBITED = ( - "The requested action has been restricted as a security measure.") + "The requested action has been restricted as a security measure." + ) + + FILE_NOT_SENT = "FILE_NOT_SENT" + FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format (e.g., JPG, PNG, PDF, TXT) and try again." + NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." diff --git a/backend/main.py b/backend/main.py index 0315e5f5..e4d4bdb5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,3 +1,5 @@ +import time + from fastapi import FastAPI, Request from fastapi.staticfiles import StaticFiles from fastapi import HTTPException @@ -5,16 +7,17 @@ from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException + from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app from apps.web.main import app as webui_app +from apps.rag.main import app as rag_app -import time +from config import ENV class SPAStaticFiles(StaticFiles): - async def get_response(self, path: str, scope): try: return await super().get_response(path, scope) @@ -25,7 +28,7 @@ class SPAStaticFiles(StaticFiles): raise ex -app = FastAPI() +app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) origins = ["*"] @@ -49,9 +52,10 @@ async def check_url(request: Request, call_next): app.mount("/api/v1", webui_app) + app.mount("/ollama/api", ollama_app) app.mount("/openai/api", openai_app) +app.mount("/rag/api/v1", rag_app) -app.mount("/", - SPAStaticFiles(directory="../build", html=True), - name="spa-static-files") + +app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") diff --git a/backend/requirements.txt b/backend/requirements.txt index 6da59fb6..d3355b5f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,6 +16,13 @@ aiohttp peewee bcrypt + +langchain +langchain-community +chromadb +sentence_transformers +pypdf + PyJWT pyjwt[crypto] diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts new file mode 100644 index 00000000..bafd0360 --- /dev/null +++ b/src/lib/apis/rag/index.ts @@ -0,0 +1,105 @@ +import { RAG_API_BASE_URL } from '$lib/constants'; + +export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { + const data = new FormData(); + data.append('file', file); + data.append('collection_name', collection_name); + + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + }, + body: data + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const uploadWebToVectorDB = async (token: string, collection_name: string, url: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/web`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url: url, + collection_name: collection_name + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const queryVectorDB = async ( + token: string, + collection_name: string, + query: string, + k: number +) => { + let error = null; + const searchParams = new URLSearchParams(); + + searchParams.set('query', query); + if (k) { + searchParams.set('k', k.toString()); + } + + const res = await fetch( + `${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 1468310d..36511e59 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -2,10 +2,11 @@ import toast from 'svelte-french-toast'; import { onMount, tick } from 'svelte'; import { settings } from '$lib/stores'; - import { findWordIndices } from '$lib/utils'; + import { calculateSHA256, findWordIndices } from '$lib/utils'; import Prompts from './MessageInput/PromptCommands.svelte'; import Suggestions from './MessageInput/Suggestions.svelte'; + import { uploadDocToVectorDB } from '$lib/apis/rag'; export let submitPrompt: Function; export let stopResponse: Function; @@ -98,7 +99,7 @@ dragged = true; }); - dropZone.addEventListener('drop', (e) => { + dropZone.addEventListener('drop', async (e) => { e.preventDefault(); console.log(e); @@ -115,14 +116,32 @@ ]; }; - if ( - e.dataTransfer?.files && - e.dataTransfer?.files.length > 0 && - ['image/gif', 'image/jpeg', 'image/png'].includes(e.dataTransfer?.files[0]['type']) - ) { - reader.readAsDataURL(e.dataTransfer?.files[0]); + const inputFiles = e.dataTransfer?.files; + + if (inputFiles && inputFiles.length > 0) { + const file = inputFiles[0]; + if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) { + reader.readAsDataURL(file); + } else if (['application/pdf', 'text/plain'].includes(file['type'])) { + console.log(file); + const hash = (await calculateSHA256(file)).substring(0, 63); + const res = await uploadDocToVectorDB(localStorage.token, hash, file); + + if (res) { + files = [ + ...files, + { + type: 'doc', + name: file.name, + collection_name: res.collection_name + } + ]; + } + } else { + toast.error(`Unsupported File Type '${file['type']}'.`); + } } else { - toast.error(`Unsupported File Type '${e.dataTransfer?.files[0]['type']}'.`); + toast.error(`File not found.`); } } @@ -145,11 +164,11 @@