forked from open-webui/open-webui

Merge pull request #1687 from buroa/buroa/huggingface-embeddings

feat: move to native `sentence_transformers`

commit 0546ad58be (7 changed files with 153 additions and 268 deletions)
CHANGELOG.md

```diff
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.121] - 2024-04-22
+
+### Added
+
+- **🛠️ Improved Embedding Model Support**: You can now use any embedding model `sentence_transformers` supports.
+
 ## [0.1.120] - 2024-04-20
 
 ### Added
```
Dockerfile

```diff
@@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
 # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
-ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2
+# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 
 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
@@ -98,13 +98,13 @@ RUN pip3 install uv && \
         # If you use CUDA the whisper and embedding model will be downloaded on first use
         pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
         uv pip install --system -r requirements.txt --no-cache-dir && \
-        python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
-        python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+        python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+        python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     else \
         pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
        uv pip install --system -r requirements.txt --no-cache-dir && \
-        python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
-        python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+        python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+        python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     fi
```
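The pre-download step above is what keeps first-request latency down: instantiating the model at build time bakes the weights into the image. A minimal standalone sketch of the same idea, assuming `sentence-transformers` is installed; the fallback model name mirrors the new default:

```python
import os

from sentence_transformers import SentenceTransformer

# Equivalent of the Dockerfile's `python -c` build step: constructing the
# model once downloads and caches the weights, so the running container
# never has to fetch them on first use.
model_name = os.environ.get(
    "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
SentenceTransformer(model_name, device="cpu")
```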
backend/apps/rag/main.py

```diff
@@ -13,7 +13,6 @@ import os, shutil, logging, re
 from pathlib import Path
 from typing import List
 
-from chromadb.utils import embedding_functions
 from chromadb.utils.batch_utils import create_batches
 
 from langchain_community.document_loaders import (
@@ -38,6 +37,7 @@ import mimetypes
 import uuid
 import json
 
+import sentence_transformers
 
 from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
 
@@ -48,11 +48,8 @@ from apps.web.models.documents import (
 )
 
 from apps.rag.utils import (
-    query_doc,
     query_embeddings_doc,
-    query_collection,
     query_embeddings_collection,
-    get_embedding_model_path,
     generate_openai_embeddings,
 )
 
@@ -69,7 +66,7 @@ from config import (
     DOCS_DIR,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
     DEVICE_TYPE,
```
```diff
@@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
-
-app.state.sentence_transformer_ef = (
-    embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=get_embedding_model_path(
-            app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
-        ),
+if app.state.RAG_EMBEDDING_ENGINE == "":
+    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+        app.state.RAG_EMBEDDING_MODEL,
         device=DEVICE_TYPE,
+        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     )
-)
 
 
 origins = ["*"]
@@ -185,13 +179,10 @@ async def update_embedding_config(
                 app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.OPENAI_API_KEY = form_data.openai_config.key
         else:
-            sentence_transformer_ef = (
-                embedding_functions.SentenceTransformerEmbeddingFunction(
-                    model_name=get_embedding_model_path(
-                        form_data.embedding_model, True
-                    ),
-                    device=DEVICE_TYPE,
-                )
+            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+                app.state.RAG_EMBEDDING_MODEL,
+                device=DEVICE_TYPE,
+                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
             )
             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
             app.state.sentence_transformer_ef = sentence_transformer_ef
```
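Both initialization paths now hold a plain `SentenceTransformer` rather than a Chroma embedding function, so every caller computes vectors explicitly via `.encode(...)`. A small sketch of the encode semantics the handlers below rely on; the 384-dimension figure assumes the default MiniLM model:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

# A single string yields one vector; a list of strings yields one vector
# per text. .tolist() turns the numpy output into plain Python lists,
# which is the shape Chroma expects.
query_vec = model.encode("what changed in 0.1.121?").tolist()  # len(query_vec) == 384
doc_vecs = model.encode(["first chunk", "second chunk"]).tolist()  # 2 x 384
```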
```diff
@@ -294,38 +285,34 @@ def query_doc_handler(
     form_data: QueryDocForm,
     user=Depends(get_current_user),
 ):
-
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "":
-            return query_doc(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-        else:
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-                query_embeddings = generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{
-                            "model": app.state.RAG_EMBEDDING_MODEL,
-                            "prompt": form_data.query,
-                        }
-                    )
-                )
-            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-                query_embeddings = generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=form_data.query,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
+            query_embeddings = app.state.sentence_transformer_ef.encode(
+                form_data.query
+            ).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            query_embeddings = generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
+            )
+        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            query_embeddings = generate_openai_embeddings(
+                model=app.state.RAG_EMBEDDING_MODEL,
+                text=form_data.query,
+                key=app.state.OPENAI_API_KEY,
+                url=app.state.OPENAI_API_BASE_URL,
+            )
 
-            return query_embeddings_doc(
-                collection_name=form_data.collection_name,
-                query_embeddings=query_embeddings,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-            )
+        return query_embeddings_doc(
+            collection_name=form_data.collection_name,
+            query=form_data.query,
+            query_embeddings=query_embeddings,
+            k=form_data.k if form_data.k else app.state.TOP_K,
+        )
 
     except Exception as e:
         log.exception(e)
```
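Because the handler now passes precomputed vectors down, the Chroma lookup no longer needs an embedding function attached to the collection. A hedged sketch of the underlying Chroma calls that `query_embeddings_doc` performs, with placeholder collection name, vectors, and k:

```python
import chromadb

client = chromadb.Client()  # stand-in for the app's shared CHROMA_CLIENT

collection = client.create_collection(name="example-docs")
collection.add(
    ids=["a", "b"],
    embeddings=[[0.0] * 384, [1.0] * 384],  # placeholder 384-dim vectors
    documents=["first chunk", "second chunk"],
)

# Chroma accepts precomputed vectors via query_embeddings; n_results is k.
result = collection.query(query_embeddings=[[0.9] * 384], n_results=2)
print(result["documents"])
```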
```diff
@@ -348,36 +335,31 @@ def query_collection_handler(
 ):
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "":
-            return query_collection(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-        else:
-
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-                query_embeddings = generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{
-                            "model": app.state.RAG_EMBEDDING_MODEL,
-                            "prompt": form_data.query,
-                        }
-                    )
-                )
-            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-                query_embeddings = generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=form_data.query,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
+            query_embeddings = app.state.sentence_transformer_ef.encode(
+                form_data.query
+            ).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            query_embeddings = generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
+            )
+        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            query_embeddings = generate_openai_embeddings(
+                model=app.state.RAG_EMBEDDING_MODEL,
+                text=form_data.query,
+                key=app.state.OPENAI_API_KEY,
+                url=app.state.OPENAI_API_BASE_URL,
+            )
 
-            return query_embeddings_collection(
-                collection_names=form_data.collection_names,
-                query_embeddings=query_embeddings,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-            )
+        return query_embeddings_collection(
+            collection_names=form_data.collection_names,
+            query_embeddings=query_embeddings,
+            k=form_data.k if form_data.k else app.state.TOP_K,
+        )
 
     except Exception as e:
         log.exception(e)
```
```diff
@@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
+    texts = list(map(lambda x: x.replace("\n", " "), texts))
+
     metadatas = [doc.metadata for doc in docs]
 
     try:
@@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                     log.info(f"deleting existing collection {collection_name}")
                     CHROMA_CLIENT.delete_collection(name=collection_name)
 
+        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
         if app.state.RAG_EMBEDDING_ENGINE == "":
-
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-
-            for batch in create_batches(
-                api=CHROMA_CLIENT,
-                ids=[str(uuid.uuid1()) for _ in texts],
-                metadatas=metadatas,
-                documents=texts,
-            ):
-                collection.add(*batch)
-
-        else:
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
-
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-                embeddings = [
-                    generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
-                        )
+            embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            embeddings = [
+                generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                     )
-                    for text in texts
-                ]
-            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-                embeddings = [
-                    generate_openai_embeddings(
-                        model=app.state.RAG_EMBEDDING_MODEL,
-                        text=text,
-                        key=app.state.OPENAI_API_KEY,
-                        url=app.state.OPENAI_API_BASE_URL,
-                    )
-                    for text in texts
-                ]
+                )
+                for text in texts
+            ]
+        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            embeddings = [
+                generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=text,
+                    key=app.state.OPENAI_API_KEY,
+                    url=app.state.OPENAI_API_BASE_URL,
+                )
+                for text in texts
+            ]
 
-            for batch in create_batches(
-                api=CHROMA_CLIENT,
-                ids=[str(uuid.uuid1()) for _ in texts],
-                metadatas=metadatas,
-                embeddings=embeddings,
-                documents=texts,
-            ):
-                collection.add(*batch)
+        for batch in create_batches(
+            api=CHROMA_CLIENT,
+            ids=[str(uuid.uuid1()) for _ in texts],
+            metadatas=metadatas,
+            embeddings=embeddings,
+            documents=texts,
+        ):
+            collection.add(*batch)
 
         return True
     except Exception as e:
```
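With the engine branches collapsed, indexing always ends in the same tail: build one `embeddings` list, then write to Chroma in size-limited batches. A minimal sketch of that shared path, under the same placeholder assumptions as above:

```python
import uuid

import chromadb
from chromadb.utils.batch_utils import create_batches
from sentence_transformers import SentenceTransformer

client = chromadb.Client()
collection = client.create_collection(name="example-store")

texts = ["first chunk\nwith a newline", "second chunk"]
texts = [t.replace("\n", " ") for t in texts]  # same normalization as the diff

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
embeddings = model.encode(texts).tolist()

# create_batches splits payloads that exceed Chroma's max batch size; each
# batch unpacks positionally into collection.add(ids, embeddings, metadatas, documents).
for batch in create_batches(
    api=client,
    ids=[str(uuid.uuid1()) for _ in texts],
    metadatas=[{"source": "example"} for _ in texts],
    embeddings=embeddings,
    documents=texts,
):
    collection.add(*batch)

print(collection.count())  # 2
```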
backend/apps/rag/utils.py

```diff
@@ -1,13 +1,12 @@
-import os
-import re
 import logging
-from typing import List
 import requests
 
-
-from huggingface_hub import snapshot_download
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
+from typing import List
+
+from apps.ollama.main import (
+    generate_ollama_embeddings,
+    GenerateEmbeddingsForm,
+)
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
@@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def query_doc(collection_name: str, query: str, k: int, embedding_function):
-    try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-            embedding_function=embedding_function,
-        )
-        result = collection.query(
-            query_texts=[query],
-            n_results=k,
-        )
-        return result
-    except Exception as e:
-        raise e
-
-
-def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
+def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
         log.info(f"query_embeddings_doc {query_embeddings}")
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-        )
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)
+
         result = collection.query(
             query_embeddings=[query_embeddings],
             n_results=k,
@@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
     return merged_query_results
 
 
-def query_collection(
-    collection_names: List[str], query: str, k: int, embedding_function
-):
-
-    results = []
-
-    for collection_name in collection_names:
-        try:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-                embedding_function=embedding_function,
-            )
-
-            result = collection.query(
-                query_texts=[query],
-                n_results=k,
-            )
-            results.append(result)
-        except:
-            pass
-
-    return merge_and_sort_query_results(results, k)
-
-
-def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
-
+def query_embeddings_collection(
+    collection_names: List[str], query: str, query_embeddings, k: int
+):
     results = []
     log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
-            result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+            result = query_embeddings_doc(
+                collection_name=collection_name,
+                query=query,
+                query_embeddings=query_embeddings,
+                k=k,
             )
             results.append(result)
         except:
@@ -197,51 +156,38 @@ def rag_messages(
                 context = doc["content"]
             else:
                 if embedding_engine == "":
-                    if doc["type"] == "collection":
-                        context = query_collection(
-                            collection_names=doc["collection_names"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
-                    else:
-                        context = query_doc(
-                            collection_name=doc["collection_name"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
-
-                else:
-                    if embedding_engine == "ollama":
-                        query_embeddings = generate_ollama_embeddings(
-                            GenerateEmbeddingsForm(
-                                **{
-                                    "model": embedding_model,
-                                    "prompt": query,
-                                }
-                            )
-                        )
-                    elif embedding_engine == "openai":
-                        query_embeddings = generate_openai_embeddings(
-                            model=embedding_model,
-                            text=query,
-                            key=openai_key,
-                            url=openai_url,
-                        )
-
-                    if doc["type"] == "collection":
-                        context = query_embeddings_collection(
-                            collection_names=doc["collection_names"],
-                            query_embeddings=query_embeddings,
-                            k=k,
-                        )
-                    else:
-                        context = query_embeddings_doc(
-                            collection_name=doc["collection_name"],
-                            query_embeddings=query_embeddings,
-                            k=k,
-                        )
+                    query_embeddings = embedding_function.encode(query).tolist()
+                elif embedding_engine == "ollama":
+                    query_embeddings = generate_ollama_embeddings(
+                        GenerateEmbeddingsForm(
+                            **{
+                                "model": embedding_model,
+                                "prompt": query,
+                            }
+                        )
+                    )
+                elif embedding_engine == "openai":
+                    query_embeddings = generate_openai_embeddings(
+                        model=embedding_model,
+                        text=query,
+                        key=openai_key,
+                        url=openai_url,
+                    )
 
+                if doc["type"] == "collection":
+                    context = query_embeddings_collection(
+                        collection_names=doc["collection_names"],
+                        query=query,
+                        query_embeddings=query_embeddings,
+                        k=k,
+                    )
+                else:
+                    context = query_embeddings_doc(
+                        collection_name=doc["collection_name"],
+                        query=query,
+                        query_embeddings=query_embeddings,
+                        k=k,
+                    )
 
         except Exception as e:
             log.exception(e)
@@ -283,46 +229,6 @@ def rag_messages(
     return messages
 
 
-def get_embedding_model_path(
-    embedding_model: str, update_embedding_model: bool = False
-):
-    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
-    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
-
-    local_files_only = not update_embedding_model
-
-    snapshot_kwargs = {
-        "cache_dir": cache_dir,
-        "local_files_only": local_files_only,
-    }
-
-    log.debug(f"embedding_model: {embedding_model}")
-    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
-
-    # Inspiration from upstream sentence_transformers
-    if (
-        os.path.exists(embedding_model)
-        or ("\\" in embedding_model or embedding_model.count("/") > 1)
-        and local_files_only
-    ):
-        # If fully qualified path exists, return input, else set repo_id
-        return embedding_model
-    elif "/" not in embedding_model:
-        # Set valid repo_id for model short-name
-        embedding_model = "sentence-transformers" + "/" + embedding_model
-
-    snapshot_kwargs["repo_id"] = embedding_model
-
-    # Attempt to query the huggingface_hub library to determine the local path and/or to update
-    try:
-        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
-        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
-        return embedding_model_repo_path
-    except Exception as e:
-        log.exception(f"Cannot determine embedding model snapshot path: {e}")
-        return embedding_model
-
-
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
```
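`generate_openai_embeddings` itself is untouched by this PR and its body is not shown in the diff. The following is a hypothetical sketch of what a helper with that signature typically does against an OpenAI-compatible `/embeddings` endpoint, not the project's actual implementation:

```python
import requests


def generate_openai_embeddings_sketch(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    # Illustrative only: POST to the standard embeddings route and unwrap
    # the first vector; the real helper in utils.py may differ in details.
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={"Authorization": f"Bearer {key}"},
            json={"input": [text], "model": model},
        )
        r.raise_for_status()
        return r.json()["data"][0]["embedding"]
    except Exception:
        return None
```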
backend/config.py

```diff
@@ -418,18 +418,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
+# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
 
-RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+RAG_EMBEDDING_MODEL = os.environ.get(
+    "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
+)
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
 
-RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
 
-
 # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
 USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
 
```
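Both knobs are plain environment variables, so switching models is a deploy-time decision; remember the Dockerfile's warning that changing models requires re-embedding existing documents. A quick hedged illustration of the parsing rules above:

```python
import os

# Hypothetical overrides, set before the backend starts.
os.environ["RAG_EMBEDDING_MODEL"] = "intfloat/multilingual-e5-base"
os.environ["RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE"] = "false"

# Only the exact case-insensitive string "true" enables trust_remote_code;
# "1", "yes", or an unset variable all leave it disabled.
trust = os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
print(trust)  # False
```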
backend/requirements.txt

```diff
@@ -27,6 +27,7 @@ apscheduler
 google-generativeai
 
 langchain
+langchain-chroma
 langchain-community
 fake_useragent
 chromadb
@@ -45,6 +46,7 @@ opencv-python-headless
 rapidocr-onnxruntime
 
 fpdf2
+rank_bm25
 
 faster-whisper
 
```
Documents settings component (Svelte)

```diff
@@ -180,7 +180,7 @@
 							}
 						}}
 					>
-						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
+						<option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
 						<option value="ollama">{$i18n.t('Ollama')}</option>
 						<option value="openai">{$i18n.t('OpenAI')}</option>
 					</select>
```