Merge pull request #1554 from open-webui/external-embeddings

feat: external embeddings
This commit is contained in:
Timothy Jaeryang Baek 2024-04-14 16:57:57 -07:00 committed by GitHub
commit 54a4b7db14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 288 additions and 101 deletions

View file

@ -659,7 +659,7 @@ def generate_ollama_embeddings(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
): ):
log.info("generate_ollama_embeddings", form_data) log.info(f"generate_ollama_embeddings {form_data}")
if url_idx == None: if url_idx == None:
model = form_data.model model = form_data.model
@ -688,7 +688,7 @@ def generate_ollama_embeddings(
data = r.json() data = r.json()
log.info("generate_ollama_embeddings", data) log.info(f"generate_ollama_embeddings {data}")
if "embedding" in data: if "embedding" in data:
return data["embedding"] return data["embedding"]

View file

@ -53,6 +53,7 @@ from apps.rag.utils import (
query_collection, query_collection,
query_embeddings_collection, query_embeddings_collection,
get_embedding_model_path, get_embedding_model_path,
generate_openai_embeddings,
) )
from utils.misc import ( from utils.misc import (
@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""
app.state.PDF_EXTRACT_IMAGES = False app.state.PDF_EXTRACT_IMAGES = False
@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
} }
class OpenAIConfigForm(BaseModel):
url: str
key: str
class EmbeddingModelUpdateForm(BaseModel): class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str embedding_engine: str
embedding_model: str embedding_model: str
@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_config( async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
try: try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = None app.state.sentence_transformer_ef = None
if form_data.openai_config != None:
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
else: else:
sentence_transformer_ef = ( sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
@ -183,6 +198,10 @@ async def update_embedding_config(
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
} }
except Exception as e: except Exception as e:
@ -275,6 +294,14 @@ def query_doc_handler(
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "":
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
@ -284,19 +311,20 @@ def query_doc_handler(
} }
) )
) )
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_doc( return query_embeddings_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -317,6 +345,15 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "":
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
@ -326,19 +363,20 @@ def query_collection_handler(
} }
) )
) )
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_collection( return query_embeddings_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -383,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
docs = text_splitter.split_documents(data) docs = text_splitter.split_documents(data)
if len(docs) > 0: if len(docs) > 0:
log.info("store_data_in_vector_db", "store_docs_in_vector_db") log.info(f"store_data_in_vector_db {docs}")
return store_docs_in_vector_db(docs, collection_name, overwrite), None return store_docs_in_vector_db(docs, collection_name, overwrite), None
else: else:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@ -402,7 +440,7 @@ def store_text_in_vector_db(
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
log.info("store_docs_in_vector_db", docs, collection_name) log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}") log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name) CHROMA_CLIENT.delete_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "":
collection = CHROMA_CLIENT.create_collection(name=collection_name)
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=[
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
],
):
collection.add(*batch)
else:
collection = CHROMA_CLIENT.create_collection( collection = CHROMA_CLIENT.create_collection(
name=collection_name, name=collection_name,
@ -446,6 +467,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
): ):
collection.add(*batch) collection.add(*batch)
else:
collection = CHROMA_CLIENT.create_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
embeddings = [
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
]
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
embeddings = [
generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=text,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
for text in texts
]
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
documents=texts,
):
collection.add(*batch)
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View file

@ -6,9 +6,12 @@ import requests
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
from config import SRC_LOG_LEVELS, CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
def query_embeddings_doc(collection_name: str, query_embeddings, k: int): def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
try: try:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
log.info("query_embeddings_doc", query_embeddings) log.info(f"query_embeddings_doc {query_embeddings}")
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=collection_name, name=collection_name,
) )
@ -40,6 +43,8 @@ def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
query_embeddings=[query_embeddings], query_embeddings=[query_embeddings],
n_results=k, n_results=k,
) )
log.info(f"query_embeddings_doc:result {result}")
return result return result
except Exception as e: except Exception as e:
raise e raise e
@ -118,7 +123,7 @@ def query_collection(
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int): def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
results = [] results = []
log.info("query_embeddings_collection", query_embeddings) log.info(f"query_embeddings_collection {query_embeddings}")
for collection_name in collection_names: for collection_name in collection_names:
try: try:
@ -141,8 +146,20 @@ def rag_template(template: str, context: str, query: str):
return template return template
def rag_messages(docs, messages, template, k, embedding_function): def rag_messages(
log.debug(f"docs: {docs}") docs,
messages,
template,
k,
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
):
log.debug(
f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
)
last_user_message_idx = None last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1): for i in range(len(messages) - 1, -1, -1):
@ -175,6 +192,11 @@ def rag_messages(docs, messages, template, k, embedding_function):
context = None context = None
try: try:
if doc["type"] == "text":
context = doc["content"]
else:
if embedding_engine == "":
if doc["type"] == "collection": if doc["type"] == "collection":
context = query_collection( context = query_collection(
collection_names=doc["collection_names"], collection_names=doc["collection_names"],
@ -182,8 +204,6 @@ def rag_messages(docs, messages, template, k, embedding_function):
k=k, k=k,
embedding_function=embedding_function, embedding_function=embedding_function,
) )
elif doc["type"] == "text":
context = doc["content"]
else: else:
context = query_doc( context = query_doc(
collection_name=doc["collection_name"], collection_name=doc["collection_name"],
@ -191,6 +211,38 @@ def rag_messages(docs, messages, template, k, embedding_function):
k=k, k=k,
embedding_function=embedding_function, embedding_function=embedding_function,
) )
else:
if embedding_engine == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
query_embeddings = generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
if doc["type"] == "collection":
context = query_embeddings_collection(
collection_names=doc["collection_names"],
query_embeddings=query_embeddings,
k=k,
)
else:
context = query_embeddings_doc(
collection_name=doc["collection_name"],
query_embeddings=query_embeddings,
k=k,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context = None context = None
@ -269,3 +321,26 @@ def get_embedding_model_path(
except Exception as e: except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}") log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model return embedding_model
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com"
):
try:
r = requests.post(
f"{url}/v1/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None

View file

@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
data["messages"], data["messages"],
rag_app.state.RAG_TEMPLATE, rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K, rag_app.state.TOP_K,
rag_app.state.RAG_EMBEDDING_ENGINE,
rag_app.state.RAG_EMBEDDING_MODEL,
rag_app.state.sentence_transformer_ef, rag_app.state.sentence_transformer_ef,
rag_app.state.RAG_OPENAI_API_KEY,
rag_app.state.RAG_OPENAI_API_BASE_URL,
) )
del data["docs"] del data["docs"]

View file

@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
return res; return res;
}; };
type OpenAIConfigForm = {
key: string;
url: string;
};
type EmbeddingModelUpdateForm = { type EmbeddingModelUpdateForm = {
openai_config?: OpenAIConfigForm;
embedding_engine: string; embedding_engine: string;
embedding_model: string; embedding_model: string;
}; };

View file

@ -29,6 +29,9 @@
let embeddingEngine = ''; let embeddingEngine = '';
let embeddingModel = ''; let embeddingModel = '';
let openAIKey = '';
let openAIUrl = '';
let chunkSize = 0; let chunkSize = 0;
let chunkOverlap = 0; let chunkOverlap = 0;
let pdfExtractImages = true; let pdfExtractImages = true;
@ -50,15 +53,6 @@
}; };
const embeddingModelUpdateHandler = async () => { const embeddingModelUpdateHandler = async () => {
if (embeddingModel === '') {
toast.error(
$i18n.t(
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
)
);
return;
}
if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) { if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
toast.error( toast.error(
$i18n.t( $i18n.t(
@ -67,21 +61,46 @@
); );
return; return;
} }
if (embeddingEngine === 'ollama' && embeddingModel === '') {
toast.error(
$i18n.t(
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
)
);
return;
}
if (embeddingEngine === 'openai' && embeddingModel === '') {
toast.error(
$i18n.t(
'Model filesystem path detected. Model shortname is required for update, cannot continue.'
)
);
return;
}
if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
toast.error($i18n.t('OpenAI URL/Key required.'));
return;
}
console.log('Update embedding model attempt:', embeddingModel); console.log('Update embedding model attempt:', embeddingModel);
updateEmbeddingModelLoading = true; updateEmbeddingModelLoading = true;
const res = await updateEmbeddingConfig(localStorage.token, { const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine, embedding_engine: embeddingEngine,
embedding_model: embeddingModel embedding_model: embeddingModel,
...(embeddingEngine === 'openai'
? {
openai_config: {
key: openAIKey,
url: openAIUrl
}
}
: {})
}).catch(async (error) => { }).catch(async (error) => {
toast.error(error); toast.error(error);
await setEmbeddingConfig();
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
}
return null; return null;
}); });
updateEmbeddingModelLoading = false; updateEmbeddingModelLoading = false;
@ -89,7 +108,7 @@
if (res) { if (res) {
console.log('embeddingModelUpdateHandler:', res); console.log('embeddingModelUpdateHandler:', res);
if (res.status === true) { if (res.status === true) {
toast.success($i18n.t('Model {{embedding_model}} update complete!', res), { toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
duration: 1000 * 10 duration: 1000 * 10
}); });
} }
@ -107,6 +126,18 @@
querySettings = await updateQuerySettings(localStorage.token, querySettings); querySettings = await updateQuerySettings(localStorage.token, querySettings);
}; };
const setEmbeddingConfig = async () => {
const embeddingConfig = await getEmbeddingConfig(localStorage.token);
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
openAIKey = embeddingConfig.openai_config.key;
openAIUrl = embeddingConfig.openai_config.url;
}
};
onMount(async () => { onMount(async () => {
const res = await getRAGConfig(localStorage.token); const res = await getRAGConfig(localStorage.token);
@ -117,12 +148,7 @@
chunkOverlap = res.chunk.chunk_overlap; chunkOverlap = res.chunk.chunk_overlap;
} }
const embeddingConfig = await getEmbeddingConfig(localStorage.token); await setEmbeddingConfig();
if (embeddingConfig) {
embeddingEngine = embeddingConfig.embedding_engine;
embeddingModel = embeddingConfig.embedding_model;
}
querySettings = await getQuerySettings(localStorage.token); querySettings = await getQuerySettings(localStorage.token);
}); });
@ -146,15 +172,38 @@
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right" class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
bind:value={embeddingEngine} bind:value={embeddingEngine}
placeholder="Select an embedding engine" placeholder="Select an embedding engine"
on:change={() => { on:change={(e) => {
if (e.target.value === 'ollama') {
embeddingModel = ''; embeddingModel = '';
} else if (e.target.value === 'openai') {
embeddingModel = 'text-embedding-3-small';
}
}} }}
> >
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option> <option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
<option value="ollama">{$i18n.t('Ollama')}</option> <option value="ollama">{$i18n.t('Ollama')}</option>
<option value="openai">{$i18n.t('OpenAI')}</option>
</select> </select>
</div> </div>
</div> </div>
{#if embeddingEngine === 'openai'}
<div class="mt-1 flex gap-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Base URL')}
bind:value={openAIUrl}
required
/>
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('API Key')}
bind:value={openAIKey}
required
/>
</div>
{/if}
</div> </div>
<div class="space-y-2"> <div class="space-y-2">