forked from open-webui/open-webui
feat: external embeddings support
This commit is contained in:
parent
8b10b058e5
commit
2952e61167
6 changed files with 312 additions and 118 deletions
|
@ -654,6 +654,55 @@ async def generate_embeddings(
|
|||
)
|
||||
|
||||
|
||||
def generate_ollama_embeddings(
|
||||
form_data: GenerateEmbeddingsForm,
|
||||
url_idx: Optional[int] = None,
|
||||
):
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/embeddings",
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
data = r.json()
|
||||
|
||||
if "embedding" in data:
|
||||
return data["embedding"]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"Ollama: {res['error']}"
|
||||
except:
|
||||
error_detail = f"Ollama: {e}"
|
||||
|
||||
raise error_detail
|
||||
|
||||
|
||||
class GenerateCompletionForm(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
|
|
|
@ -39,13 +39,21 @@ import uuid
|
|||
import json
|
||||
|
||||
|
||||
from apps.ollama.main import generate_ollama_embeddings
|
||||
|
||||
from apps.web.models.documents import (
|
||||
Documents,
|
||||
DocumentForm,
|
||||
DocumentResponse,
|
||||
)
|
||||
|
||||
from apps.rag.utils import query_doc, query_collection, get_embedding_model_path
|
||||
from apps.rag.utils import (
|
||||
query_doc,
|
||||
query_embeddings_doc,
|
||||
query_collection,
|
||||
query_embeddings_collection,
|
||||
get_embedding_model_path,
|
||||
)
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
|
@ -58,6 +66,7 @@ from config import (
|
|||
SRC_LOG_LEVELS,
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
DEVICE_TYPE,
|
||||
|
@ -74,17 +83,20 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
app.state.PDF_EXTRACT_IMAGES = False
|
||||
|
||||
app.state.TOP_K = 4
|
||||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
|
||||
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.PDF_EXTRACT_IMAGES = False
|
||||
|
||||
|
||||
app.state.TOP_K = 4
|
||||
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=get_embedding_model_path(
|
||||
|
@ -121,6 +133,7 @@ async def get_status():
|
|||
"chunk_size": app.state.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
|
@ -252,6 +265,17 @@ def query_doc_handler(
|
|||
):
|
||||
|
||||
try:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
|
||||
)
|
||||
|
||||
return query_embeddings_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
|
@ -277,12 +301,30 @@ def query_collection_handler(
|
|||
form_data: QueryCollectionsForm,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
|
||||
)
|
||||
|
||||
return query_embeddings_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
else:
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/web")
|
||||
|
@ -317,6 +359,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
|
|||
chunk_overlap=app.state.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
|
||||
docs = text_splitter.split_documents(data)
|
||||
|
||||
if len(docs) > 0:
|
||||
|
@ -337,7 +380,9 @@ def store_text_in_vector_db(
|
|||
return store_docs_in_vector_db(docs, collection_name, overwrite)
|
||||
|
||||
|
||||
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
||||
async def store_docs_in_vector_db(
|
||||
docs, collection_name, overwrite: bool = False
|
||||
) -> bool:
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
@ -349,6 +394,22 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
log.info(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=[
|
||||
generate_ollama_embeddings(
|
||||
{"model": RAG_EMBEDDING_MODEL, "prompt": text}
|
||||
)
|
||||
for text in texts
|
||||
],
|
||||
):
|
||||
collection.add(*batch)
|
||||
else:
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
|
|
|
@ -2,6 +2,9 @@ import os
|
|||
import re
|
||||
import logging
|
||||
from typing import List
|
||||
import requests
|
||||
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
@ -26,6 +29,21 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
|
|||
raise e
|
||||
|
||||
|
||||
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
)
|
||||
result = collection.query(
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=k,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(query_results, k):
|
||||
# Initialize lists to store combined data
|
||||
combined_ids = []
|
||||
|
@ -96,6 +114,24 @@ def query_collection(
|
|||
return merge_and_sort_query_results(results, k)
|
||||
|
||||
|
||||
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
|
||||
|
||||
results = []
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
|
||||
result = collection.query(
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=k,
|
||||
)
|
||||
results.append(result)
|
||||
except:
|
||||
pass
|
||||
|
||||
return merge_and_sort_query_results(results, k)
|
||||
|
||||
|
||||
def rag_template(template: str, context: str, query: str):
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace("[query]", query)
|
||||
|
|
|
@ -405,6 +405,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
|||
|
||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
|
||||
|
||||
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
|
||||
|
||||
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
|
||||
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
|
||||
|
||||
|
|
|
@ -220,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa
|
|||
return res;
|
||||
};
|
||||
|
||||
export const generateEmbeddings = async (token: string = '', model: string, text: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: model,
|
||||
prompt: text
|
||||
})
|
||||
}).catch((err) => {
|
||||
error = err;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const generateTextCompletion = async (token: string = '', model: string, text: string) => {
|
||||
let error = null;
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
let showResetConfirm = false;
|
||||
|
||||
let embeddingEngine = '';
|
||||
let chunkSize = 0;
|
||||
let chunkOverlap = 0;
|
||||
let pdfExtractImages = true;
|
||||
|
@ -119,58 +120,25 @@
|
|||
<div class=" mb-2 text-sm font-medium">{$i18n.t('General Settings')}</div>
|
||||
|
||||
<div class=" flex w-full justify-between">
|
||||
<div class=" self-center text-xs font-medium">
|
||||
{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })}
|
||||
</div>
|
||||
|
||||
<button
|
||||
class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded-lg flex flex-row space-x-1 items-center {scanDirLoading
|
||||
? ' cursor-not-allowed'
|
||||
: ''}"
|
||||
on:click={() => {
|
||||
scanHandler();
|
||||
console.log('check');
|
||||
}}
|
||||
type="button"
|
||||
disabled={scanDirLoading}
|
||||
<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Engine')}</div>
|
||||
<div class="flex items-center relative">
|
||||
<select
|
||||
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
|
||||
bind:value={embeddingEngine}
|
||||
placeholder="Select an embedding engine"
|
||||
>
|
||||
<div class="self-center font-medium">{$i18n.t('Scan')}</div>
|
||||
|
||||
{#if scanDirLoading}
|
||||
<div class="ml-3 self-center">
|
||||
<svg
|
||||
class=" w-3 h-3"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
><style>
|
||||
.spinner_ajPY {
|
||||
transform-origin: center;
|
||||
animation: spinner_AtaB 0.75s infinite linear;
|
||||
}
|
||||
@keyframes spinner_AtaB {
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
</style><path
|
||||
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
|
||||
opacity=".25"
|
||||
/><path
|
||||
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
|
||||
class="spinner_ajPY"
|
||||
/></svg
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
|
||||
<option value="ollama">{$i18n.t('Ollama')}</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr class=" dark:border-gray-700" />
|
||||
|
||||
<div class="space-y-2">
|
||||
<div>
|
||||
{#if embeddingEngine === 'ollama'}
|
||||
<div>da</div>
|
||||
{:else}
|
||||
<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Embedding Model')}</div>
|
||||
<div class="flex w-full">
|
||||
<div class="flex-1 mr-2">
|
||||
|
@ -238,6 +206,57 @@
|
|||
'Warning: If you update or change your embedding model, you will need to re-import all documents.'
|
||||
)}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<hr class=" dark:border-gray-700 my-3" />
|
||||
|
||||
<div class=" flex w-full justify-between">
|
||||
<div class=" self-center text-xs font-medium">
|
||||
{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })}
|
||||
</div>
|
||||
|
||||
<button
|
||||
class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded-lg flex flex-row space-x-1 items-center {scanDirLoading
|
||||
? ' cursor-not-allowed'
|
||||
: ''}"
|
||||
on:click={() => {
|
||||
scanHandler();
|
||||
console.log('check');
|
||||
}}
|
||||
type="button"
|
||||
disabled={scanDirLoading}
|
||||
>
|
||||
<div class="self-center font-medium">{$i18n.t('Scan')}</div>
|
||||
|
||||
{#if scanDirLoading}
|
||||
<div class="ml-3 self-center">
|
||||
<svg
|
||||
class=" w-3 h-3"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
><style>
|
||||
.spinner_ajPY {
|
||||
transform-origin: center;
|
||||
animation: spinner_AtaB 0.75s infinite linear;
|
||||
}
|
||||
@keyframes spinner_AtaB {
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
</style><path
|
||||
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
|
||||
opacity=".25"
|
||||
/><path
|
||||
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
|
||||
class="spinner_ajPY"
|
||||
/></svg
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<hr class=" dark:border-gray-700 my-3" />
|
||||
|
||||
|
|
Loading…
Reference in a new issue