Merge branch 'dev' into feat/Teams_Incoming_Webhook

This commit is contained in:
Chris 2024-03-26 16:59:42 +08:00 committed by GitHub
commit 0b62bbb52e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
45 changed files with 1400 additions and 335 deletions

2
.gitignore vendored
View file

@ -166,7 +166,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
# Logs # Logs
logs logs

View file

@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.115] - 2024-03-24
### Added
- **🔍 Custom Model Selector**: Easily find and select custom models with the new search filter feature.
- **🛑 Cancel Model Download**: Added the ability to cancel model downloads.
- **🎨 Image Generation ComfyUI**: Image generation now supports ComfyUI.
- **🌟 Updated Light Theme**: Updated the light theme for a fresh look.
- **🌍 Additional Language Support**: Now supporting Bulgarian, Italian, Portuguese, Japanese, and Dutch.
### Fixed
- **🔧 Fixed Broken Experimental GGUF Upload**: Resolved issues with experimental GGUF upload functionality.
### Changed
- **🔄 Vector Storage Reset Button**: Moved the reset vector storage button to document settings.
## [0.1.114] - 2024-03-20 ## [0.1.114] - 2024-03-20
### Added ### Added

View file

@ -1,4 +1,5 @@
import os import os
import logging
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request, Request,
@ -21,7 +22,10 @@ from utils.utils import (
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
@ -38,7 +42,7 @@ def transcribe(
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
print(file.content_type) log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav"]: if file.content_type not in ["audio/mpeg", "audio/wav"]:
raise HTTPException( raise HTTPException(
@ -62,7 +66,7 @@ def transcribe(
) )
segments, info = model.transcribe(file_path, beam_size=5) segments, info = model.transcribe(file_path, beam_size=5)
print( log.info(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
% (info.language, info.language_probability) % (info.language, info.language_probability)
) )
@ -72,7 +76,7 @@ def transcribe(
return {"text": transcript.strip()} return {"text": transcript.strip()}
except Exception as e: except Exception as e:
print(e) log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View file

@ -27,10 +27,14 @@ from pathlib import Path
import uuid import uuid
import base64 import base64
import json import json
import logging
from config import CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@ -304,7 +308,7 @@ def save_b64_image(b64_str):
return image_id return image_id
except Exception as e: except Exception as e:
print(f"Error saving image: {e}") log.error(f"Error saving image: {e}")
return None return None
@ -431,7 +435,7 @@ def generate_image(
res = r.json() res = r.json()
print(res) log.debug(f"res: {res}")
images = [] images = []

View file

@ -1,3 +1,5 @@
import logging
from litellm.proxy.proxy_server import ProxyConfig, initialize from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app from litellm.proxy.proxy_server import app
@ -9,7 +11,10 @@ from starlette.responses import StreamingResponse
import json import json
from utils.utils import get_http_authorization_cred, get_current_user from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV from config import SRC_LOG_LEVELS, ENV
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import ( from config import (
@ -49,7 +54,7 @@ async def auth_middleware(request: Request, call_next):
try: try:
user = get_current_user(get_http_authorization_cred(auth_header)) user = get_current_user(get_http_authorization_cred(auth_header))
print(user) log.debug(f"user: {user}")
request.state.user = user request.state.user = user
except Exception as e: except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)}) return JSONResponse(status_code=400, content={"detail": str(e)})

View file

@ -23,6 +23,7 @@ import json
import uuid import uuid
import aiohttp import aiohttp
import asyncio import asyncio
import logging
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Optional, List, Union from typing import Optional, List, Union
@ -30,11 +31,13 @@ from typing import Optional, List, Union
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user from utils.utils import decode_token, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
log = logging.getLogger(__name__)
from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
@ -85,7 +88,7 @@ class UrlUpdateForm(BaseModel):
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_BASE_URLS = form_data.urls app.state.OLLAMA_BASE_URLS = form_data.urls
print(app.state.OLLAMA_BASE_URLS) log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
@ -106,7 +109,7 @@ async def fetch_url(url):
return await response.json() return await response.json()
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.error(f"Connection error: {e}")
return None return None
@ -130,7 +133,7 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
print("get_all_models") log.info("get_all_models()")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
@ -171,7 +174,7 @@ async def get_ollama_tags(
return r.json() return r.json()
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -217,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
return r.json() return r.json()
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -243,18 +246,33 @@ async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
r = None r = None
def get_request(): def get_request():
nonlocal url nonlocal url
nonlocal r nonlocal r
request_id = str(uuid.uuid4())
try: try:
REQUEST_POOL.append(request_id)
def stream_content(): def stream_content():
for chunk in r.iter_content(chunk_size=8192): try:
yield chunk yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request( r = requests.request(
method="POST", method="POST",
@ -275,8 +293,9 @@ async def pull_model(
try: try:
return await run_in_threadpool(get_request) return await run_in_threadpool(get_request)
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -315,7 +334,7 @@ async def push_model(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.debug(f"url: {url}")
r = None r = None
@ -347,7 +366,7 @@ async def push_model(
try: try:
return await run_in_threadpool(get_request) return await run_in_threadpool(get_request)
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -375,9 +394,9 @@ class CreateModelForm(BaseModel):
async def create_model( async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
print(form_data) log.debug(f"form_data: {form_data}")
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
r = None r = None
@ -399,7 +418,7 @@ async def create_model(
r.raise_for_status() r.raise_for_status()
print(r) log.debug(f"r: {r}")
return StreamingResponse( return StreamingResponse(
stream_content(), stream_content(),
@ -412,7 +431,7 @@ async def create_model(
try: try:
return await run_in_threadpool(get_request) return await run_in_threadpool(get_request)
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -450,7 +469,7 @@ async def copy_model(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
try: try:
r = requests.request( r = requests.request(
@ -460,11 +479,11 @@ async def copy_model(
) )
r.raise_for_status() r.raise_for_status()
print(r.text) log.debug(f"r.text: {r.text}")
return True return True
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -497,7 +516,7 @@ async def delete_model(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
try: try:
r = requests.request( r = requests.request(
@ -507,11 +526,11 @@ async def delete_model(
) )
r.raise_for_status() r.raise_for_status()
print(r.text) log.debug(f"r.text: {r.text}")
return True return True
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -537,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
try: try:
r = requests.request( r = requests.request(
@ -549,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
return r.json() return r.json()
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -589,7 +608,7 @@ async def generate_embeddings(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
try: try:
r = requests.request( r = requests.request(
@ -601,7 +620,7 @@ async def generate_embeddings(
return r.json() return r.json()
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -649,7 +668,7 @@ async def generate_completion(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
r = None r = None
@ -670,7 +689,7 @@ async def generate_completion(
if request_id in REQUEST_POOL: if request_id in REQUEST_POOL:
yield chunk yield chunk
else: else:
print("User: canceled request") log.warning("User: canceled request")
break break
finally: finally:
if hasattr(r, "close"): if hasattr(r, "close"):
@ -747,11 +766,11 @@ async def generate_chat_completion(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
r = None r = None
print(form_data.model_dump_json(exclude_none=True).encode()) log.debug("form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(form_data.model_dump_json(exclude_none=True).encode()))
def get_request(): def get_request():
nonlocal form_data nonlocal form_data
@ -770,7 +789,7 @@ async def generate_chat_completion(
if request_id in REQUEST_POOL: if request_id in REQUEST_POOL:
yield chunk yield chunk
else: else:
print("User: canceled request") log.warning("User: canceled request")
break break
finally: finally:
if hasattr(r, "close"): if hasattr(r, "close"):
@ -793,7 +812,7 @@ async def generate_chat_completion(
headers=dict(r.headers), headers=dict(r.headers),
) )
except Exception as e: except Exception as e:
print(e) log.exception(e)
raise e raise e
try: try:
@ -847,7 +866,7 @@ async def generate_openai_chat_completion(
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
print(url) log.info(f"url: {url}")
r = None r = None
@ -870,7 +889,7 @@ async def generate_openai_chat_completion(
if request_id in REQUEST_POOL: if request_id in REQUEST_POOL:
yield chunk yield chunk
else: else:
print("User: canceled request") log.warning("User: canceled request")
break break
finally: finally:
if hasattr(r, "close"): if hasattr(r, "close"):
@ -1168,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
if request_id in REQUEST_POOL: if request_id in REQUEST_POOL:
yield chunk yield chunk
else: else:
print("User: canceled request") log.warning("User: canceled request")
break break
finally: finally:
if hasattr(r, "close"): if hasattr(r, "close"):

View file

@ -6,6 +6,7 @@ import requests
import aiohttp import aiohttp
import asyncio import asyncio
import json import json
import logging
from pydantic import BaseModel from pydantic import BaseModel
@ -19,6 +20,7 @@ from utils.utils import (
get_admin_user, get_admin_user,
) )
from config import ( from config import (
SRC_LOG_LEVELS,
OPENAI_API_BASE_URLS, OPENAI_API_BASE_URLS,
OPENAI_API_KEYS, OPENAI_API_KEYS,
CACHE_DIR, CACHE_DIR,
@ -31,6 +33,9 @@ from typing import List, Optional
import hashlib import hashlib
from pathlib import Path from pathlib import Path
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -160,7 +165,7 @@ async def fetch_url(url, key):
return await response.json() return await response.json()
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") log.error(f"Connection error: {e}")
return None return None
@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
print("get_all_models") log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
models = {"data": []} models = {"data": []}
@ -208,7 +213,7 @@ async def get_all_models():
) )
} }
print(models) log.info(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]} app.state.MODELS = {model["id"]: model for model in models["data"]}
return models return models
@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return response_data return response_data
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if body.get("model") == "gpt-4-vision-preview": if body.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in body: if "max_tokens" not in body:
body["max_tokens"] = 4000 body["max_tokens"] = 4000
print("Modified body_dict:", body) log.debug("Modified body_dict:", body)
# Fix for ChatGPT calls failing because the num_ctx key is in body # Fix for ChatGPT calls failing because the num_ctx key is in body
if "num_ctx" in body: if "num_ctx" in body:
@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON # Convert the modified body back to JSON
body = json.dumps(body) body = json.dumps(body)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print("Error loading request body into a dictionary:", e) log.error("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx] url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx] key = app.state.OPENAI_API_KEYS[idx]
@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
response_data = r.json() response_data = r.json()
return response_data return response_data
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:

View file

@ -8,7 +8,7 @@ from fastapi import (
Form, Form,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import os, shutil import os, shutil, logging
from pathlib import Path from pathlib import Path
from typing import List from typing import List
@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
TextLoader, TextLoader,
PyPDFLoader, PyPDFLoader,
CSVLoader, CSVLoader,
BSHTMLLoader,
Docx2txtLoader, Docx2txtLoader,
UnstructuredEPubLoader, UnstructuredEPubLoader,
UnstructuredWordDocumentLoader, UnstructuredWordDocumentLoader,
@ -54,6 +55,7 @@ from utils.misc import (
) )
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from config import ( from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR, UPLOAD_DIR,
DOCS_DIR, DOCS_DIR,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
@ -66,6 +68,9 @@ from config import (
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
# #
# if RAG_EMBEDDING_MODEL: # if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer( # sentence_transformer_ef = SentenceTransformer(
@ -111,39 +116,6 @@ class StoreWebForm(CollectionNameForm):
url: str url: str
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
print(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return { return {
@ -274,7 +246,7 @@ def query_doc_handler(
embedding_function=app.state.sentence_transformer_ef, embedding_function=app.state.sentence_transformer_ef,
) )
except Exception as e: except Exception as e:
print(e) log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e), detail=ERROR_MESSAGES.DEFAULT(e),
@ -318,13 +290,69 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
"filename": form_data.url, "filename": form_data.url,
} }
except Exception as e: except Exception as e:
print(e) log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e), detail=ERROR_MESSAGES.DEFAULT(e),
) )
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.split_documents(data)
if len(docs) > 0:
return store_docs_in_vector_db(docs, collection_name, overwrite), None
else:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
return store_docs_in_vector_db(docs, collection_name, overwrite)
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
print(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
def get_loader(filename: str, file_content_type: str, file_path: str): def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower() file_ext = filename.split(".")[-1].lower()
known_type = True known_type = True
@ -382,6 +410,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
loader = UnstructuredRSTLoader(file_path, mode="elements") loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml": elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path) loader = UnstructuredXMLLoader(file_path)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md": elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path) loader = UnstructuredMarkdownLoader(file_path)
elif file_content_type == "application/epub+zip": elif file_content_type == "application/epub+zip":
@ -416,7 +446,7 @@ def store_doc(
): ):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
print(file.content_type) log.info(f"file.content_type: {file.content_type}")
try: try:
filename = file.filename filename = file.filename
file_path = f"{UPLOAD_DIR}/{filename}" file_path = f"{UPLOAD_DIR}/{filename}"
@ -432,22 +462,24 @@ def store_doc(
loader, known_type = get_loader(file.filename, file.content_type, file_path) loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load() data = loader.load()
result = store_data_in_vector_db(data, collection_name)
if result: try:
return { result = store_data_in_vector_db(data, collection_name)
"status": True,
"collection_name": collection_name, if result:
"filename": filename, return {
"known_type": known_type, "status": True,
} "collection_name": collection_name,
else: "filename": filename,
"known_type": known_type,
}
except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(), detail=e,
) )
except Exception as e: except Exception as e:
print(e) log.exception(e)
if "No pandoc was found" in str(e): if "No pandoc was found" in str(e):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -460,6 +492,37 @@ def store_doc(
) )
class TextRAGForm(BaseModel):
name: str
content: str
collection_name: Optional[str] = None
@app.post("/text")
def store_text(
form_data: TextRAGForm,
user=Depends(get_current_user),
):
collection_name = form_data.collection_name
if collection_name == None:
collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db(
form_data.content,
metadata={"name": form_data.name, "created_by": user.id},
collection_name=collection_name,
)
if result:
return {"status": True, "collection_name": collection_name}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
@app.get("/scan") @app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)): def scan_docs_dir(user=Depends(get_admin_user)):
for path in Path(DOCS_DIR).rglob("./**/*"): for path in Path(DOCS_DIR).rglob("./**/*"):
@ -478,41 +541,45 @@ def scan_docs_dir(user=Depends(get_admin_user)):
) )
data = loader.load() data = loader.load()
result = store_data_in_vector_db(data, collection_name) try:
result = store_data_in_vector_db(data, collection_name)
if result: if result:
sanitized_filename = sanitize_filename(filename) sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename) doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None: if doc == None:
doc = Documents.insert_new_doc( doc = Documents.insert_new_doc(
user.id, user.id,
DocumentForm( DocumentForm(
**{ **{
"name": sanitized_filename, "name": sanitized_filename,
"title": filename, "title": filename,
"collection_name": collection_name, "collection_name": collection_name,
"filename": filename, "filename": filename,
"content": ( "content": (
json.dumps( json.dumps(
{ {
"tags": list( "tags": list(
map( map(
lambda name: {"name": name}, lambda name: {"name": name},
tags, tags,
)
) )
) }
} )
) if len(tags)
if len(tags) else "{}"
else "{}" ),
), }
} ),
), )
) except Exception as e:
print(e)
pass
except Exception as e: except Exception as e:
print(e) log.exception(e)
return True return True
@ -533,11 +600,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
elif os.path.isdir(file_path): elif os.path.isdir(file_path):
shutil.rmtree(file_path) shutil.rmtree(file_path)
except Exception as e: except Exception as e:
print("Failed to delete %s. Reason: %s" % (file_path, e)) log.error("Failed to delete %s. Reason: %s" % (file_path, e))
try: try:
CHROMA_CLIENT.reset() CHROMA_CLIENT.reset()
except Exception as e: except Exception as e:
print(e) log.exception(e)
return True return True

View file

@ -1,7 +1,11 @@
import re import re
import logging
from typing import List from typing import List
from config import CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_doc(collection_name: str, query: str, k: int, embedding_function): def query_doc(collection_name: str, query: str, k: int, embedding_function):
@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
def rag_messages(docs, messages, template, k, embedding_function): def rag_messages(docs, messages, template, k, embedding_function):
print(docs) log.debug(f"docs: {docs}")
last_user_message_idx = None last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1): for i in range(len(messages) - 1, -1, -1):
@ -137,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
k=k, k=k,
embedding_function=embedding_function, embedding_function=embedding_function,
) )
elif doc["type"] == "text":
context = doc["content"]
else: else:
context = query_doc( context = query_doc(
collection_name=doc["collection_name"], collection_name=doc["collection_name"],
@ -145,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
embedding_function=embedding_function, embedding_function=embedding_function,
) )
except Exception as e: except Exception as e:
print(e) log.exception(e)
context = None context = None
relevant_contexts.append(context) relevant_contexts.append(context)

View file

@ -1,13 +1,16 @@
from peewee import * from peewee import *
from config import DATA_DIR from config import SRC_LOG_LEVELS, DATA_DIR
import os import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
# Check if the file exists # Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"): if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file # Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
print("File renamed successfully.") log.info("File renamed successfully.")
else: else:
pass pass

View file

@ -2,6 +2,7 @@ from pydantic import BaseModel
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
import uuid import uuid
import logging
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.web.models.users import UserModel, Users
@ -9,6 +10,10 @@ from utils.utils import verify_password
from apps.web.internal.db import DB from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# DB MODEL # DB MODEL
#################### ####################
@ -86,7 +91,7 @@ class AuthsTable:
def insert_new_auth( def insert_new_auth(
self, email: str, password: str, name: str, role: str = "pending" self, email: str, password: str, name: str, role: str = "pending"
) -> Optional[UserModel]: ) -> Optional[UserModel]:
print("insert_new_auth") log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -103,7 +108,7 @@ class AuthsTable:
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
print("authenticate_user", email) log.info(f"authenticate_user: {email}")
try: try:
auth = Auth.get(Auth.email == email, Auth.active == True) auth = Auth.get(Auth.email == email, Auth.active == True)
if auth: if auth:

View file

@ -3,6 +3,7 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
import logging
from utils.utils import decode_token from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
@ -11,6 +12,10 @@ from apps.web.internal.db import DB
import json import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Documents DB Schema # Documents DB Schema
#################### ####################
@ -118,7 +123,7 @@ class DocumentsTable:
doc = Document.get(Document.name == form_data.name) doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc)) return DocumentModel(**model_to_dict(doc))
except Exception as e: except Exception as e:
print(e) log.exception(e)
return None return None
def update_doc_content_by_name( def update_doc_content_by_name(
@ -138,7 +143,7 @@ class DocumentsTable:
doc = Document.get(Document.name == name) doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc)) return DocumentModel(**model_to_dict(doc))
except Exception as e: except Exception as e:
print(e) log.exception(e)
return None return None
def delete_doc_by_name(self, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:

View file

@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
import logging
from apps.web.internal.db import DB from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Tag DB Schema # Tag DB Schema
#################### ####################
@ -173,7 +178,7 @@ class TagTable:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id) (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
) )
res = query.execute() # Remove the rows, return number of rows removed. res = query.execute() # Remove the rows, return number of rows removed.
print(res) log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0: if tag_count == 0:
@ -185,7 +190,7 @@ class TagTable:
return True return True
except Exception as e: except Exception as e:
print("delete_tag", e) log.error(f"delete_tag: {e}")
return False return False
def delete_tag_by_tag_name_and_chat_id_and_user_id( def delete_tag_by_tag_name_and_chat_id_and_user_id(
@ -198,7 +203,7 @@ class TagTable:
& (ChatIdTag.user_id == user_id) & (ChatIdTag.user_id == user_id)
) )
res = query.execute() # Remove the rows, return number of rows removed. res = query.execute() # Remove the rows, return number of rows removed.
print(res) log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0: if tag_count == 0:
@ -210,7 +215,7 @@ class TagTable:
return True return True
except Exception as e: except Exception as e:
print("delete_tag", e) log.error(f"delete_tag: {e}")
return False return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

View file

@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
import logging
from apps.web.models.users import Users from apps.web.models.users import Users
from apps.web.models.chats import ( from apps.web.models.chats import (
@ -27,6 +28,10 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################ ############################
@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e: except Exception as e:
print(e) log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
) )
@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
tags = Tags.get_tags_by_user_id(user.id) tags = Tags.get_tags_by_user_id(user.id)
return tags return tags
except Exception as e: except Exception as e:
print(e) log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
) )

View file

@ -7,6 +7,7 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import time import time
import uuid import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################ ############################
@ -83,7 +88,7 @@ async def update_user_by_id(
if form_data.password: if form_data.password:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
print(hashed) log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(user_id, hashed) Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower()) Auths.update_email_by_id(user_id, form_data.email.lower())

View file

@ -1,24 +1,29 @@
import json
import os import os
import shutil import sys
from base64 import b64encode import logging
from pathlib import Path
from secrets import token_bytes
import chromadb import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from pathlib import Path
import json
import yaml
import markdown import markdown
import requests import requests
import yaml import shutil
from bs4 import BeautifulSoup
from chromadb import Settings from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
try: try:
from dotenv import find_dotenv, load_dotenv from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env")) load_dotenv(find_dotenv("../.env"))
except ImportError: except ImportError:
print("dotenv not installed, skipping...") log.warning("dotenv not installed, skipping...")
WEBUI_NAME = "Open WebUI" WEBUI_NAME = "Open WebUI"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
@ -98,6 +103,34 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json CHANGELOG = changelog_json
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = ["AUDIO", "CONFIG", "DB", "IMAGES", "LITELLM", "MAIN", "MODELS", "OLLAMA", "OPENAI", "RAG"]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
#################################### ####################################
# CUSTOM_NAME # CUSTOM_NAME
#################################### ####################################
@ -123,7 +156,7 @@ if CUSTOM_NAME:
WEBUI_NAME = data["name"] WEBUI_NAME = data["name"]
except Exception as e: except Exception as e:
print(e) log.exception(e)
pass pass
@ -192,9 +225,9 @@ def create_config_file(file_path):
LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml" LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
if not os.path.exists(LITELLM_CONFIG_PATH): if not os.path.exists(LITELLM_CONFIG_PATH):
print("Config file doesn't exist. Creating...") log.info("Config file doesn't exist. Creating...")
create_config_file(LITELLM_CONFIG_PATH) create_config_file(LITELLM_CONFIG_PATH)
print("Config file created successfully.") log.info("Config file created successfully.")
#################################### ####################################
@ -206,7 +239,7 @@ OLLAMA_API_BASE_URL = os.environ.get(
) )
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
KUBERNETES_SERVICE_HOST = os.environ.get("KUBERNETES_SERVICE_HOST", "")
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "": if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
OLLAMA_BASE_URL = ( OLLAMA_BASE_URL = (
@ -216,8 +249,10 @@ if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
) )
if ENV == "prod": if ENV == "prod":
if OLLAMA_BASE_URL == "/ollama": if OLLAMA_BASE_URL == "/ollama" and KUBERNETES_SERVICE_HOST == "":
OLLAMA_BASE_URL = "http://host.docker.internal:11434" OLLAMA_BASE_URL = "http://host.docker.internal:11434"
else:
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")

View file

@ -60,3 +60,5 @@ class ERROR_MESSAGES(str, Enum):
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."

View file

@ -4,6 +4,7 @@ import markdown
import time import time
import os import os
import sys import sys
import logging
import requests import requests
from fastapi import FastAPI, Request, Depends, status from fastapi import FastAPI, Request, Depends, status
@ -38,10 +39,15 @@ from config import (
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
MODEL_FILTER_ENABLED, MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
WEBHOOK_URL, WEBHOOK_URL,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class SPAStaticFiles(StaticFiles): class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope): async def get_response(self, path: str, scope):
@ -70,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
if request.method == "POST" and ( if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path "/api/chat" in request.url.path or "/chat/completions" in request.url.path
): ):
print(request.url.path) log.debug(f"request.url.path: {request.url.path}")
# Read the original request body # Read the original request body
body = await request.body() body = await request.body()
@ -93,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
) )
del data["docs"] del data["docs"]
print(data["messages"]) log.debug(f"data['messages']: {data['messages']}")
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")

View file

@ -35,4 +35,4 @@ spec:
volumes: volumes:
- name: webui-volume - name: webui-volume
persistentVolumeClaim: persistentVolumeClaim:
claimName: ollama-webui-pvc claimName: open-webui-pvc

View file

@ -2,8 +2,8 @@ apiVersion: v1
kind: PersistentVolumeClaim kind: PersistentVolumeClaim
metadata: metadata:
labels: labels:
app: ollama-webui app: open-webui
name: ollama-webui-pvc name: open-webui-pvc
namespace: open-webui namespace: open-webui
spec: spec:
accessModes: ["ReadWriteOnce"] accessModes: ["ReadWriteOnce"]

View file

@ -33,7 +33,7 @@ export const getLiteLLMModels = async (token: string = '') => {
id: model.id, id: model.id,
name: model.name ?? model.id, name: model.name ?? model.id,
external: true, external: true,
source: 'litellm' source: 'LiteLLM'
})) }))
.sort((a, b) => { .sort((a, b) => {
return a.name.localeCompare(b.name); return a.name.localeCompare(b.name);

View file

@ -271,7 +271,7 @@ export const generateChatCompletion = async (token: string = '', body: object) =
return [res, controller]; return [res, controller];
}; };
export const cancelChatCompletion = async (token: string = '', requestId: string) => { export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {

View file

@ -263,3 +263,53 @@ export const synthesizeOpenAISpeech = async (
return res; return res;
}; };
export const generateTitle = async (
token: string = '',
template: string,
model: string,
prompt: string,
url: string = OPENAI_API_BASE_URL
) => {
let error = null;
template = template.replace(/{{prompt}}/g, prompt);
console.log(template);
const res = await fetch(`${url}/chat/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
messages: [
{
role: 'user',
content: template
}
],
stream: false
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
return res?.choices[0]?.message?.content ?? 'New Chat';
};

View file

@ -3,6 +3,7 @@
import { models, showSettings, settings, user } from '$lib/stores'; import { models, showSettings, settings, user } from '$lib/stores';
import { onMount, tick, getContext } from 'svelte'; import { onMount, tick, getContext } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import Selector from './ModelSelector/Selector.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -32,30 +33,24 @@
} }
</script> </script>
<div class="flex flex-col my-2"> <div class="flex flex-col my-2 w-full">
{#each selectedModels as selectedModel, selectedModelIdx} {#each selectedModels as selectedModel, selectedModelIdx}
<div class="flex"> <div class="flex w-full">
<select <div class="overflow-hidden w-full">
id="models" <div class="mr-2 max-w-full">
class="outline-none bg-transparent text-lg font-semibold rounded-lg block w-full placeholder-gray-400" <Selector
bind:value={selectedModel} placeholder={$i18n.t('Select a model')}
{disabled} items={$models
> .filter((model) => model.name !== 'hr')
<option class=" text-gray-700" value="" selected disabled .map((model) => ({
>{$i18n.t('Select a model')}</option value: model.id,
> label: model.name,
info: model
{#each $models as model} }))}
{#if model.name === 'hr'} bind:value={selectedModel}
<hr /> />
{:else} </div>
<option value={model.id} class="text-gray-700 text-lg" </div>
>{model.name +
`${model.size ? ` (${(model.size / 1024 ** 3).toFixed(1)}GB)` : ''}`}</option
>
{/if}
{/each}
</select>
{#if selectedModelIdx === 0} {#if selectedModelIdx === 0}
<button <button
@ -136,6 +131,6 @@
{/each} {/each}
</div> </div>
<div class="text-left mt-1.5 text-xs text-gray-500"> <div class="text-left mt-1.5 ml-1 text-xs text-gray-500">
<button on:click={saveDefaultModel}> {$i18n.t('Set as default')}</button> <button on:click={saveDefaultModel}> {$i18n.t('Set as default')}</button>
</div> </div>

View file

@ -0,0 +1,389 @@
<script lang="ts">
import { Select } from 'bits-ui';
import { flyAndScale } from '$lib/utils/transitions';
import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import Check from '$lib/components/icons/Check.svelte';
import Search from '$lib/components/icons/Search.svelte';
import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { user, MODEL_DOWNLOAD_POOL, models } from '$lib/stores';
import { toast } from 'svelte-sonner';
import { capitalizeFirstLetter, getModels, splitStream } from '$lib/utils';
import Tooltip from '$lib/components/common/Tooltip.svelte';
const i18n = getContext('i18n');
const dispatch = createEventDispatcher();
export let value = '';
export let placeholder = 'Select a model';
export let searchEnabled = true;
export let searchPlaceholder = 'Search a model';
export let items = [{ value: 'mango', label: 'Mango' }];
let searchValue = '';
let ollamaVersion = null;
$: filteredItems = searchValue
? items.filter((item) => item.value.includes(searchValue.toLowerCase()))
: items;
const pullModelHandler = async () => {
const sanitizedModelTag = searchValue.trim();
console.log($MODEL_DOWNLOAD_POOL);
if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag]) {
toast.error(
$i18n.t(`Model '{{modelTag}}' is already in queue for downloading.`, {
modelTag: sanitizedModelTag
})
);
return;
}
if (Object.keys($MODEL_DOWNLOAD_POOL).length === 3) {
toast.error(
$i18n.t('Maximum of 3 models can be downloaded simultaneously. Please try again later.')
);
return;
}
const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
toast.error(error);
return null;
});
if (res) {
const reader = res.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(splitStream('\n'))
.getReader();
while (true) {
try {
const { value, done } = await reader.read();
if (done) break;
let lines = value.split('\n');
for (const line of lines) {
if (line !== '') {
let data = JSON.parse(line);
console.log(data);
if (data.error) {
throw data.error;
}
if (data.detail) {
throw data.detail;
}
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
if (data.status) {
if (data.digest) {
let downloadProgress = 0;
if (data.completed) {
downloadProgress = Math.round((data.completed / data.total) * 1000) / 10;
} else {
downloadProgress = 100;
}
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
pullProgress: downloadProgress,
digest: data.digest
}
});
} else {
toast.success(data.status);
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
done: data.status === 'success'
}
});
}
}
}
}
} catch (error) {
console.log(error);
if (typeof error !== 'string') {
error = error.message;
}
toast.error(error);
// opts.callback({ success: false, error, modelName: opts.modelName });
}
}
if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag].done) {
toast.success(
$i18n.t(`Model '{{modelName}}' has been successfully downloaded.`, {
modelName: sanitizedModelTag
})
);
models.set(await getModels(localStorage.token));
} else {
toast.error('Download canceled');
}
delete $MODEL_DOWNLOAD_POOL[sanitizedModelTag];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL
});
}
};
onMount(async () => {
ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
});
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete $MODEL_DOWNLOAD_POOL[model];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL
});
await deleteModel(localStorage.token, model);
toast.success(`${model} download has been canceled`);
}
};
</script>
<Select.Root
{items}
onOpenChange={async () => {
searchValue = '';
window.setTimeout(() => document.getElementById('model-search-input')?.focus(), 0);
}}
selected={items.find((item) => item.value === value)}
onSelectedChange={(selectedItem) => {
value = selectedItem.value;
}}
>
<Select.Trigger class="relative w-full" aria-label={placeholder}>
<Select.Value
class="inline-flex h-input px-0.5 w-full outline-none bg-transparent truncate text-lg font-semibold placeholder-gray-400 focus:outline-none"
{placeholder}
/>
<ChevronDown className="absolute end-2 top-1/2 -translate-y-[45%] size-3.5" strokeWidth="2.5" />
</Select.Trigger>
<Select.Content
class="w-full rounded-lg bg-white dark:bg-gray-900 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-700/50 outline-none"
transition={flyAndScale}
sideOffset={4}
>
<slot>
{#if searchEnabled}
<div class="flex items-center gap-2.5 px-5 mt-3.5 mb-3">
<Search className="size-4" strokeWidth="2.5" />
<input
id="model-search-input"
bind:value={searchValue}
class="w-full text-sm bg-transparent outline-none"
placeholder={searchPlaceholder}
/>
</div>
<hr class="border-gray-100 dark:border-gray-800" />
{/if}
<div class="px-3 my-2 max-h-80 overflow-y-auto">
{#each filteredItems as item}
<Select.Item
class="flex w-full font-medium line-clamp-1 select-none items-center rounded-button py-2 pl-3 pr-1.5 text-sm text-gray-700 dark:text-gray-100 outline-none transition-all duration-75 hover:bg-gray-100 dark:hover:bg-gray-850 rounded-lg cursor-pointer data-[highlighted]:bg-muted"
value={item.value}
label={item.label}
>
<div class="flex items-center gap-2">
<div class="line-clamp-1">
{item.label}
<span class=" text-xs font-medium text-gray-600 dark:text-gray-400"
>{item.info?.details?.parameter_size ?? ''}</span
>
</div>
<!-- {JSON.stringify(item.info)} -->
{#if item.info.external}
<Tooltip content={item.info?.source ?? 'External'}>
<div class=" mr-2">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="size-3"
>
<path
fill-rule="evenodd"
d="M8.914 6.025a.75.75 0 0 1 1.06 0 3.5 3.5 0 0 1 0 4.95l-2 2a3.5 3.5 0 0 1-5.396-4.402.75.75 0 0 1 1.251.827 2 2 0 0 0 3.085 2.514l2-2a2 2 0 0 0 0-2.828.75.75 0 0 1 0-1.06Z"
clip-rule="evenodd"
/>
<path
fill-rule="evenodd"
d="M7.086 9.975a.75.75 0 0 1-1.06 0 3.5 3.5 0 0 1 0-4.95l2-2a3.5 3.5 0 0 1 5.396 4.402.75.75 0 0 1-1.251-.827 2 2 0 0 0-3.085-2.514l-2 2a2 2 0 0 0 0 2.828.75.75 0 0 1 0 1.06Z"
clip-rule="evenodd"
/>
</svg>
</div>
</Tooltip>
{:else}
<Tooltip
content={`${
item.info?.details?.quantization_level
? item.info?.details?.quantization_level + ' '
: ''
}${item.info.size ? `(${(item.info.size / 1024 ** 3).toFixed(1)}GB)` : ''}`}
>
<div class=" mr-2">
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="m11.25 11.25.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Zm-9-3.75h.008v.008H12V8.25Z"
/>
</svg>
</div>
</Tooltip>
{/if}
</div>
{#if value === item.value}
<div class="ml-auto">
<Check />
</div>
{/if}
</Select.Item>
{:else}
<div>
<div class="block px-3 py-2 text-sm text-gray-700 dark:text-gray-100">
No results found
</div>
</div>
{/each}
{#if !(searchValue.trim() in $MODEL_DOWNLOAD_POOL) && searchValue && ollamaVersion && $user.role === 'admin'}
<button
class="flex w-full font-medium line-clamp-1 select-none items-center rounded-button py-2 pl-3 pr-1.5 text-sm text-gray-700 dark:text-gray-100 outline-none transition-all duration-75 hover:bg-gray-100 dark:hover:bg-gray-850 rounded-lg cursor-pointer data-[highlighted]:bg-muted"
on:click={() => {
pullModelHandler();
}}
>
Pull "{searchValue}" from Ollama.com
</button>
{/if}
{#each Object.keys($MODEL_DOWNLOAD_POOL) as model}
<div
class="flex w-full justify-between font-medium select-none rounded-button py-2 pl-3 pr-1.5 text-sm text-gray-700 dark:text-gray-100 outline-none transition-all duration-75 rounded-lg cursor-pointer data-[highlighted]:bg-muted"
>
<div class="flex">
<div class="-ml-2 mr-2.5 translate-y-0.5">
<svg
class="size-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div>
<div class="flex flex-col self-start">
<div class="line-clamp-1">
Downloading "{model}" {'pullProgress' in $MODEL_DOWNLOAD_POOL[model]
? `(${$MODEL_DOWNLOAD_POOL[model].pullProgress}%)`
: ''}
</div>
{#if 'digest' in $MODEL_DOWNLOAD_POOL[model] && $MODEL_DOWNLOAD_POOL[model].digest}
<div class="-mt-1 h-fit text-[0.7rem] dark:text-gray-500 line-clamp-1">
{$MODEL_DOWNLOAD_POOL[model].digest}
</div>
{/if}
</div>
</div>
<div class="mr-2 translate-y-0.5">
<Tooltip content="Cancel">
<button
class="text-gray-800 dark:text-gray-100"
on:click={() => {
cancelModelPullHandler(model);
}}
>
<svg
class="w-4 h-4 text-gray-800 dark:text-white"
aria-hidden="true"
xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
fill="currentColor"
viewBox="0 0 24 24"
>
<path
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18 17.94 6M18 18 6.06 6"
/>
</svg>
</button>
</Tooltip>
</div>
</div>
{/each}
</div>
</slot>
</Select.Content>
</Select.Root>

View file

@ -2,7 +2,6 @@
import fileSaver from 'file-saver'; import fileSaver from 'file-saver';
const { saveAs } = fileSaver; const { saveAs } = fileSaver;
import { resetVectorDB } from '$lib/apis/rag';
import { chats, user } from '$lib/stores'; import { chats, user } from '$lib/stores';
import { import {
@ -330,38 +329,6 @@
{$i18n.t('Export All Chats (All Users)')} {$i18n.t('Export All Chats (All Users)')}
</div> </div>
</button> </button>
<hr class=" dark:border-gray-700" />
<button
class=" flex rounded-md py-2 px-3.5 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
on:click={() => {
const res = resetVectorDB(localStorage.token).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success($i18n.t('Success'));
}
}}
>
<div class=" self-center mr-3">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M3.5 2A1.5 1.5 0 0 0 2 3.5v9A1.5 1.5 0 0 0 3.5 14h9a1.5 1.5 0 0 0 1.5-1.5v-7A1.5 1.5 0 0 0 12.5 4H9.621a1.5 1.5 0 0 1-1.06-.44L7.439 2.44A1.5 1.5 0 0 0 6.38 2H3.5Zm6.75 7.75a.75.75 0 0 0 0-1.5h-4.5a.75.75 0 0 0 0 1.5h4.5Z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center text-sm font-medium">{$i18n.t('Reset Vector Storage')}</div>
</button>
{/if} {/if}
</div> </div>
</div> </div>

View file

@ -1,7 +1,7 @@
<script lang="ts"> <script lang="ts">
import { getBackendConfig } from '$lib/apis'; import { getBackendConfig } from '$lib/apis';
import { setDefaultPromptSuggestions } from '$lib/apis/configs'; import { setDefaultPromptSuggestions } from '$lib/apis/configs';
import { config, models, user } from '$lib/stores'; import { config, models, settings, user } from '$lib/stores';
import { createEventDispatcher, onMount, getContext } from 'svelte'; import { createEventDispatcher, onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
@ -14,6 +14,7 @@
let titleAutoGenerate = true; let titleAutoGenerate = true;
let responseAutoCopy = false; let responseAutoCopy = false;
let titleAutoGenerateModel = ''; let titleAutoGenerateModel = '';
let titleAutoGenerateModelExternal = '';
let fullScreenMode = false; let fullScreenMode = false;
let titleGenerationPrompt = ''; let titleGenerationPrompt = '';
@ -33,7 +34,12 @@
const toggleTitleAutoGenerate = async () => { const toggleTitleAutoGenerate = async () => {
titleAutoGenerate = !titleAutoGenerate; titleAutoGenerate = !titleAutoGenerate;
saveSettings({ titleAutoGenerate: titleAutoGenerate }); saveSettings({
title: {
...$settings.title,
auto: titleAutoGenerate
}
});
}; };
const toggleResponseAutoCopy = async () => { const toggleResponseAutoCopy = async () => {
@ -65,8 +71,13 @@
} }
saveSettings({ saveSettings({
titleAutoGenerateModel: titleAutoGenerateModel !== '' ? titleAutoGenerateModel : undefined, title: {
titleGenerationPrompt: titleGenerationPrompt ? titleGenerationPrompt : undefined ...$settings.title,
model: titleAutoGenerateModel !== '' ? titleAutoGenerateModel : undefined,
modelExternal:
titleAutoGenerateModelExternal !== '' ? titleAutoGenerateModelExternal : undefined,
prompt: titleGenerationPrompt ? titleGenerationPrompt : undefined
}
}); });
}; };
@ -77,16 +88,18 @@
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}'); let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
titleAutoGenerate = settings.titleAutoGenerate ?? true; titleAutoGenerate = settings?.title?.auto ?? true;
responseAutoCopy = settings.responseAutoCopy ?? false; titleAutoGenerateModel = settings?.title?.model ?? '';
showUsername = settings.showUsername ?? false; titleAutoGenerateModelExternal = settings?.title?.modelExternal ?? '';
fullScreenMode = settings.fullScreenMode ?? false;
titleAutoGenerateModel = settings.titleAutoGenerateModel ?? '';
titleGenerationPrompt = titleGenerationPrompt =
settings.titleGenerationPrompt ?? settings?.title?.prompt ??
$i18n.t( $i18n.t(
"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':" "Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
) + ' {{prompt}}'; ) + ' {{prompt}}';
responseAutoCopy = settings.responseAutoCopy ?? false;
showUsername = settings.showUsername ?? false;
fullScreenMode = settings.fullScreenMode ?? false;
}); });
</script> </script>
@ -190,8 +203,9 @@
<div> <div>
<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Set Title Auto-Generation Model')}</div> <div class=" mb-2.5 text-sm font-medium">{$i18n.t('Set Title Auto-Generation Model')}</div>
<div class="flex w-full"> <div class="flex w-full gap-2 pr-2">
<div class="flex-1 mr-2"> <div class="flex-1">
<div class=" text-xs mb-1">Local Models</div>
<select <select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={titleAutoGenerateModel} bind:value={titleAutoGenerateModel}
@ -207,6 +221,24 @@
{/each} {/each}
</select> </select>
</div> </div>
<div class="flex-1">
<div class=" text-xs mb-1">External Models</div>
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={titleAutoGenerateModelExternal}
placeholder={$i18n.t('Select a model')}
>
<option value="" selected>{$i18n.t('Current Model')}</option>
{#each $models as model}
{#if model.name !== 'hr'}
<option value={model.name} class="bg-gray-100 dark:bg-gray-700">
{model.name}
</option>
{/if}
{/each}
</select>
</div>
</div> </div>
<div class="mt-3 mr-2"> <div class="mt-3 mr-2">

View file

@ -9,6 +9,7 @@
getOllamaUrls, getOllamaUrls,
getOllamaVersion, getOllamaVersion,
pullModel, pullModel,
cancelOllamaRequest,
uploadModel uploadModel
} from '$lib/apis/ollama'; } from '$lib/apis/ollama';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
@ -163,7 +164,7 @@
// Remove the downloaded model // Remove the downloaded model
delete modelDownloadStatus[modelName]; delete modelDownloadStatus[modelName];
console.log(data); modelDownloadStatus = { ...modelDownloadStatus };
if (!data.success) { if (!data.success) {
toast.error(data.error); toast.error(data.error);
@ -372,12 +373,24 @@
for (const line of lines) { for (const line of lines) {
if (line !== '') { if (line !== '') {
let data = JSON.parse(line); let data = JSON.parse(line);
console.log(data);
if (data.error) { if (data.error) {
throw data.error; throw data.error;
} }
if (data.detail) { if (data.detail) {
throw data.detail; throw data.detail;
} }
if (data.id) {
modelDownloadStatus[opts.modelName] = {
...modelDownloadStatus[opts.modelName],
requestId: data.id,
reader,
done: false
};
console.log(data);
}
if (data.status) { if (data.status) {
if (data.digest) { if (data.digest) {
let downloadProgress = 0; let downloadProgress = 0;
@ -387,11 +400,17 @@
downloadProgress = 100; downloadProgress = 100;
} }
modelDownloadStatus[opts.modelName] = { modelDownloadStatus[opts.modelName] = {
...modelDownloadStatus[opts.modelName],
pullProgress: downloadProgress, pullProgress: downloadProgress,
digest: data.digest digest: data.digest
}; };
} else { } else {
toast.success(data.status); toast.success(data.status);
modelDownloadStatus[opts.modelName] = {
...modelDownloadStatus[opts.modelName],
done: data.status === 'success'
};
} }
} }
} }
@ -404,7 +423,14 @@
opts.callback({ success: false, error, modelName: opts.modelName }); opts.callback({ success: false, error, modelName: opts.modelName });
} }
} }
opts.callback({ success: true, modelName: opts.modelName });
console.log(modelDownloadStatus[opts.modelName]);
if (modelDownloadStatus[opts.modelName].done) {
opts.callback({ success: true, modelName: opts.modelName });
} else {
opts.callback({ success: false, error: 'Download canceled', modelName: opts.modelName });
}
} }
}; };
@ -474,6 +500,18 @@
ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false); ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token); liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
}); });
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = modelDownloadStatus[model];
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete modelDownloadStatus[model];
await deleteModel(localStorage.token, model);
toast.success(`${model} download has been canceled`);
}
};
</script> </script>
<div class="flex flex-col h-full justify-between text-sm"> <div class="flex flex-col h-full justify-between text-sm">
@ -604,20 +642,58 @@
{#if Object.keys(modelDownloadStatus).length > 0} {#if Object.keys(modelDownloadStatus).length > 0}
{#each Object.keys(modelDownloadStatus) as model} {#each Object.keys(modelDownloadStatus) as model}
<div class="flex flex-col"> {#if 'pullProgress' in modelDownloadStatus[model]}
<div class="font-medium mb-1">{model}</div> <div class="flex flex-col">
<div class=""> <div class="font-medium mb-1">{model}</div>
<div <div class="">
class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full" <div class="flex flex-row justify-between space-x-4 pr-2">
style="width: {Math.max(15, modelDownloadStatus[model].pullProgress ?? 0)}%" <div class=" flex-1">
> <div
{modelDownloadStatus[model].pullProgress ?? 0}% class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
</div> style="width: {Math.max(
<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;"> 15,
{modelDownloadStatus[model].digest} modelDownloadStatus[model].pullProgress ?? 0
)}%"
>
{modelDownloadStatus[model].pullProgress ?? 0}%
</div>
</div>
<Tooltip content="Cancel">
<button
class="text-gray-800 dark:text-gray-100"
on:click={() => {
cancelModelPullHandler(model);
}}
>
<svg
class="w-4 h-4 text-gray-800 dark:text-white"
aria-hidden="true"
xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
fill="currentColor"
viewBox="0 0 24 24"
>
<path
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18 17.94 6M18 18 6.06 6"
/>
</svg>
</button>
</Tooltip>
</div>
{#if 'digest' in modelDownloadStatus[model]}
<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
{modelDownloadStatus[model].digest}
</div>
{/if}
</div> </div>
</div> </div>
</div> {/if}
{/each} {/each}
{/if} {/if}
</div> </div>

View file

@ -2,6 +2,8 @@
import { DropdownMenu } from 'bits-ui'; import { DropdownMenu } from 'bits-ui';
import { createEventDispatcher } from 'svelte'; import { createEventDispatcher } from 'svelte';
import { flyAndScale } from '$lib/utils/transitions';
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
</script> </script>
@ -20,6 +22,7 @@
sideOffset={8} sideOffset={8}
side="bottom" side="bottom"
align="start" align="start"
transition={flyAndScale}
> >
<DropdownMenu.Item class="flex items-center px-3 py-2 text-sm font-medium"> <DropdownMenu.Item class="flex items-center px-3 py-2 text-sm font-medium">
<div class="flex items-center">Profile</div> <div class="flex items-center">Profile</div>

View file

@ -2,6 +2,8 @@
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import { fade } from 'svelte/transition'; import { fade } from 'svelte/transition';
import { flyAndScale } from '$lib/utils/transitions';
export let show = true; export let show = true;
export let size = 'md'; export let size = 'md';
@ -41,10 +43,10 @@
}} }}
> >
<div <div
class=" modal-content m-auto rounded-2xl max-w-full {sizeToWidth( class=" m-auto rounded-2xl max-w-full {sizeToWidth(
size size
)} mx-2 bg-gray-50 dark:bg-gray-900 shadow-3xl" )} mx-2 bg-gray-50 dark:bg-gray-900 shadow-3xl"
in:fade={{ duration: 10 }} in:flyAndScale
on:click={(e) => { on:click={(e) => {
e.stopPropagation(); e.stopPropagation();
}} }}

View file

@ -0,0 +1,95 @@
<script lang="ts">
import { Select } from 'bits-ui';
import { flyAndScale } from '$lib/utils/transitions';
import { createEventDispatcher } from 'svelte';
import ChevronDown from '../icons/ChevronDown.svelte';
import Check from '../icons/Check.svelte';
import Search from '../icons/Search.svelte';
const dispatch = createEventDispatcher();
export let value = '';
export let placeholder = 'Select a model';
export let searchEnabled = true;
export let searchPlaceholder = 'Search a model';
export let items = [
{ value: 'mango', label: 'Mango' },
{ value: 'watermelon', label: 'Watermelon' },
{ value: 'apple', label: 'Apple' },
{ value: 'pineapple', label: 'Pineapple' },
{ value: 'orange', label: 'Orange' }
];
let searchValue = '';
$: filteredItems = searchValue
? items.filter((item) => item.value.includes(searchValue.toLowerCase()))
: items;
</script>
<Select.Root
{items}
onOpenChange={() => {
searchValue = '';
}}
selected={items.find((item) => item.value === value)}
onSelectedChange={(selectedItem) => {
value = selectedItem.value;
}}
>
<Select.Trigger class="relative w-full" aria-label={placeholder}>
<Select.Value
class="inline-flex h-input px-0.5 w-full outline-none bg-transparent truncate text-lg font-semibold placeholder-gray-400 focus:outline-none"
{placeholder}
/>
<ChevronDown className="absolute end-2 top-1/2 -translate-y-[45%] size-3.5" strokeWidth="2.5" />
</Select.Trigger>
<Select.Content
class="w-full rounded-lg bg-white dark:bg-gray-900 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-700/50 outline-none"
transition={flyAndScale}
sideOffset={4}
>
<slot>
{#if searchEnabled}
<div class="flex items-center gap-2.5 px-5 mt-3.5 mb-3">
<Search className="size-4" strokeWidth="2.5" />
<input
bind:value={searchValue}
class="w-full text-sm bg-transparent outline-none"
placeholder={searchPlaceholder}
/>
</div>
<hr class="border-gray-100 dark:border-gray-800" />
{/if}
<div class="px-3 my-2 max-h-80 overflow-y-auto">
{#each filteredItems as item}
<Select.Item
class="flex w-full font-medium line-clamp-1 select-none items-center rounded-button py-2 pl-3 pr-1.5 text-sm text-gray-700 dark:text-gray-100 outline-none transition-all duration-75 hover:bg-gray-100 dark:hover:bg-gray-850 rounded-lg cursor-pointer data-[highlighted]:bg-muted"
value={item.value}
label={item.label}
>
{item.label}
{#if value === item.value}
<div class="ml-auto">
<Check />
</div>
{/if}
</Select.Item>
{:else}
<div>
<div class="block px-5 py-2 text-sm text-gray-700 dark:text-gray-100">
No results found
</div>
</div>
{/each}
</div>
</slot>
</Select.Content>
</Select.Root>

View file

@ -5,8 +5,10 @@
updateRAGConfig, updateRAGConfig,
getQuerySettings, getQuerySettings,
scanDocs, scanDocs,
updateQuerySettings updateQuerySettings,
resetVectorDB
} from '$lib/apis/rag'; } from '$lib/apis/rag';
import { documents } from '$lib/stores'; import { documents } from '$lib/stores';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
@ -17,6 +19,8 @@
let loading = false; let loading = false;
let showResetConfirm = false;
let chunkSize = 0; let chunkSize = 0;
let chunkOverlap = 0; let chunkOverlap = 0;
let pdfExtractImages = true; let pdfExtractImages = true;
@ -231,6 +235,100 @@
/> />
</div> </div>
</div> </div>
<hr class=" dark:border-gray-700" />
{#if showResetConfirm}
<div class="flex justify-between rounded-md items-center py-2 px-3.5 w-full transition">
<div class="flex items-center space-x-3">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path d="M2 3a1 1 0 0 1 1-1h10a1 1 0 0 1 1 1v1a1 1 0 0 1-1 1H3a1 1 0 0 1-1-1V3Z" />
<path
fill-rule="evenodd"
d="M13 6H3v6a2 2 0 0 0 2 2h6a2 2 0 0 0 2-2V6ZM5.72 7.47a.75.75 0 0 1 1.06 0L8 8.69l1.22-1.22a.75.75 0 1 1 1.06 1.06L9.06 9.75l1.22 1.22a.75.75 0 1 1-1.06 1.06L8 10.81l-1.22 1.22a.75.75 0 0 1-1.06-1.06l1.22-1.22-1.22-1.22a.75.75 0 0 1 0-1.06Z"
clip-rule="evenodd"
/>
</svg>
<span>{$i18n.t('Are you sure?')}</span>
</div>
<div class="flex space-x-1.5 items-center">
<button
class="hover:text-white transition"
on:click={() => {
const res = resetVectorDB(localStorage.token).catch((error) => {
toast.error(error);
return null;
});
if (res) {
toast.success($i18n.t('Success'));
}
showResetConfirm = false;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M16.704 4.153a.75.75 0 01.143 1.052l-8 10.5a.75.75 0 01-1.127.075l-4.5-4.5a.75.75 0 011.06-1.06l3.894 3.893 7.48-9.817a.75.75 0 011.05-.143z"
clip-rule="evenodd"
/>
</svg>
</button>
<button
class="hover:text-white transition"
on:click={() => {
showResetConfirm = false;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button>
</div>
</div>
{:else}
<button
class=" flex rounded-md py-2 px-3.5 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
on:click={() => {
showResetConfirm = true;
}}
>
<div class=" self-center mr-3">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M3.5 2A1.5 1.5 0 0 0 2 3.5v9A1.5 1.5 0 0 0 3.5 14h9a1.5 1.5 0 0 0 1.5-1.5v-7A1.5 1.5 0 0 0 12.5 4H9.621a1.5 1.5 0 0 1-1.06-.44L7.439 2.44A1.5 1.5 0 0 0 6.38 2H3.5Zm6.75 7.75a.75.75 0 0 0 0-1.5h-4.5a.75.75 0 0 0 0 1.5h4.5Z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center text-sm font-medium">{$i18n.t('Reset Vector Storage')}</div>
</button>
{/if}
</div> </div>
<div class="flex justify-end pt-3 text-sm font-medium"> <div class="flex justify-end pt-3 text-sm font-medium">

View file

@ -0,0 +1,15 @@
<script lang="ts">
export let className = 'w-4 h-4';
export let strokeWidth = '1.5';
</script>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width={strokeWidth}
stroke="currentColor"
class={className}
>
<path stroke-linecap="round" stroke-linejoin="round" d="m4.5 12.75 6 6 9-13.5" />
</svg>

View file

@ -0,0 +1,15 @@
<script lang="ts">
export let className = 'w-4 h-4';
export let strokeWidth = '1.5';
</script>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width={strokeWidth}
stroke="currentColor"
class={className}
>
<path stroke-linecap="round" stroke-linejoin="round" d="m19.5 8.25-7.5 7.5-7.5-7.5" />
</svg>

View file

@ -0,0 +1,19 @@
<script lang="ts">
export let className = 'w-4 h-4';
export let strokeWidth = '1.5';
</script>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width={strokeWidth}
stroke="currentColor"
class={className}
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z"
/>
</svg>

View file

@ -20,10 +20,9 @@
getAllChatTags getAllChatTags
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { slide } from 'svelte/transition'; import { fade, slide } from 'svelte/transition';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
import Tooltip from '../common/Tooltip.svelte'; import Tooltip from '../common/Tooltip.svelte';
import Dropdown from '../common/Dropdown.svelte';
import ChatMenu from './Sidebar/ChatMenu.svelte'; import ChatMenu from './Sidebar/ChatMenu.svelte';
let show = false; let show = false;
@ -577,7 +576,7 @@
<div <div
id="dropdownDots" id="dropdownDots"
class="absolute z-40 bottom-[70px] 4.5rem rounded-xl shadow w-[240px] bg-white dark:bg-gray-900" class="absolute z-40 bottom-[70px] 4.5rem rounded-xl shadow w-[240px] bg-white dark:bg-gray-900"
in:slide={{ duration: 150 }} transition:fade|slide={{ duration: 100 }}
> >
<div class="py-2 w-full"> <div class="py-2 w-full">
{#if $user.role === 'admin'} {#if $user.role === 'admin'}

View file

@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import { DropdownMenu } from 'bits-ui'; import { DropdownMenu } from 'bits-ui';
import { flyAndScale } from '$lib/utils/transitions';
import Dropdown from '$lib/components/common/Dropdown.svelte'; import Dropdown from '$lib/components/common/Dropdown.svelte';
import GarbageBin from '$lib/components/icons/GarbageBin.svelte'; import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
@ -29,6 +30,7 @@
sideOffset={-2} sideOffset={-2}
side="bottom" side="bottom"
align="start" align="start"
transition={flyAndScale}
> >
<DropdownMenu.Item <DropdownMenu.Item
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer" class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer"

View file

@ -22,6 +22,7 @@ export const SUPPORTED_FILE_TYPE = [
'text/plain', 'text/plain',
'text/csv', 'text/csv',
'text/xml', 'text/xml',
'text/html',
'text/x-python', 'text/x-python',
'text/css', 'text/css',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
@ -50,6 +51,8 @@ export const SUPPORTED_FILE_EXTENSIONS = [
'h', 'h',
'c', 'c',
'cs', 'cs',
'htm',
'html',
'sql', 'sql',
'log', 'log',
'ini', 'ini',

View file

@ -279,7 +279,7 @@
"Send a Message": "Изпращане на Съобщение", "Send a Message": "Изпращане на Съобщение",
"Send message": "Изпращане на съобщение", "Send message": "Изпращане на съобщение",
"Server connection verified": "Server connection verified", "Server connection verified": "Server connection verified",
"Set as default": "Задай като по подразбиране", "Set as default": "Задай по подразбиране",
"Set Default Model": "Задай Модел По Подразбиране", "Set Default Model": "Задай Модел По Подразбиране",
"Set Image Size": "Задай Размер на Изображението", "Set Image Size": "Задай Размер на Изображението",
"Set Steps": "Задай Стъпки", "Set Steps": "Задай Стъпки",
@ -320,7 +320,7 @@
"Title": "Заглавие", "Title": "Заглавие",
"Title Auto-Generation": "Автоматично Генериране на Заглавие", "Title Auto-Generation": "Автоматично Генериране на Заглавие",
"Title Generation Prompt": "Промпт за Генериране на Заглавие", "Title Generation Prompt": "Промпт за Генериране на Заглавие",
"to": "до", "to": "в",
"To access the available model names for downloading,": "За да получите достъп до наличните имена на модели за изтегляне,", "To access the available model names for downloading,": "За да получите достъп до наличните имена на модели за изтегляне,",
"To access the GGUF models available for downloading,": "За да получите достъп до GGUF моделите, налични за изтегляне,", "To access the GGUF models available for downloading,": "За да получите достъп до GGUF моделите, налични за изтегляне,",
"to chat input.": "към чат входа.", "to chat input.": "към чат входа.",

View file

@ -7,8 +7,9 @@ export const config = writable(undefined);
export const user = writable(undefined); export const user = writable(undefined);
// Frontend // Frontend
export const theme = writable('system'); export const MODEL_DOWNLOAD_POOL = writable({});
export const theme = writable('system');
export const chatId = writable(''); export const chatId = writable('');
export const chats = writable([]); export const chats = writable([]);

View file

@ -1,10 +1,40 @@
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import sha256 from 'js-sha256'; import sha256 from 'js-sha256';
import { getOllamaModels } from '$lib/apis/ollama';
import { getOpenAIModels } from '$lib/apis/openai';
import { getLiteLLMModels } from '$lib/apis/litellm';
export const getModels = async (token: string) => {
let models = await Promise.all([
await getOllamaModels(token).catch((error) => {
console.log(error);
return null;
}),
await getOpenAIModels(token).catch((error) => {
console.log(error);
return null;
}),
await getLiteLLMModels(token).catch((error) => {
console.log(error);
return null;
})
]);
models = models
.filter((models) => models)
.reduce((a, e, i, arr) => a.concat(e, ...(i < arr.length - 1 ? [{ name: 'hr' }] : [])), []);
return models;
};
////////////////////////// //////////////////////////
// Helper functions // Helper functions
////////////////////////// //////////////////////////
export const capitalizeFirstLetter = (string) => {
return string.charAt(0).toUpperCase() + string.slice(1);
};
export const splitStream = (splitOn) => { export const splitStream = (splitOn) => {
let buffer = ''; let buffer = '';
return new TransformStream({ return new TransformStream({

View file

@ -0,0 +1,48 @@
import { cubicOut } from 'svelte/easing';
import type { TransitionConfig } from 'svelte/transition';
type FlyAndScaleParams = {
y?: number;
start?: number;
duration?: number;
};
const defaultFlyAndScaleParams = { y: -8, start: 0.95, duration: 200 };
export const flyAndScale = (node: Element, params?: FlyAndScaleParams): TransitionConfig => {
const style = getComputedStyle(node);
const transform = style.transform === 'none' ? '' : style.transform;
const withDefaults = { ...defaultFlyAndScaleParams, ...params };
const scaleConversion = (valueA: number, scaleA: [number, number], scaleB: [number, number]) => {
const [minA, maxA] = scaleA;
const [minB, maxB] = scaleB;
const percentage = (valueA - minA) / (maxA - minA);
const valueB = percentage * (maxB - minB) + minB;
return valueB;
};
const styleToString = (style: Record<string, number | string | undefined>): string => {
return Object.keys(style).reduce((str, key) => {
if (style[key] === undefined) return str;
return str + `${key}:${style[key]};`;
}, '');
};
return {
duration: withDefaults.duration ?? 200,
delay: 0,
css: (t) => {
const y = scaleConversion(t, [0, 1], [withDefaults.y, 0]);
const scale = scaleConversion(t, [0, 1], [withDefaults.start, 1]);
return styleToString({
transform: `${transform} translate3d(0, ${y}px, 0) scale(${scale})`,
opacity: t
});
},
easing: cubicOut
};
};

View file

@ -19,7 +19,7 @@
} from '$lib/stores'; } from '$lib/stores';
import { copyToClipboard, splitStream } from '$lib/utils'; import { copyToClipboard, splitStream } from '$lib/utils';
import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; import { generateChatCompletion, cancelOllamaRequest } from '$lib/apis/ollama';
import { import {
addTagById, addTagById,
createNewChat, createNewChat,
@ -30,14 +30,14 @@
updateChatById updateChatById
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { queryCollection, queryDoc } from '$lib/apis/rag'; import { queryCollection, queryDoc } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte'; import Messages from '$lib/components/chat/Messages.svelte';
import ModelSelector from '$lib/components/chat/ModelSelector.svelte'; import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
import Navbar from '$lib/components/layout/Navbar.svelte'; import Navbar from '$lib/components/layout/Navbar.svelte';
import { RAGTemplate } from '$lib/utils/rag'; import { RAGTemplate } from '$lib/utils/rag';
import { LITELLM_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants'; import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -104,7 +104,7 @@
const initNewChat = async () => { const initNewChat = async () => {
if (currentRequestId !== null) { if (currentRequestId !== null) {
await cancelChatCompletion(localStorage.token, currentRequestId); await cancelOllamaRequest(localStorage.token, currentRequestId);
currentRequestId = null; currentRequestId = null;
} }
window.history.replaceState(history.state, '', `/`); window.history.replaceState(history.state, '', `/`);
@ -372,7 +372,7 @@
if (stopResponseFlag) { if (stopResponseFlag) {
controller.abort('User: Stop Response'); controller.abort('User: Stop Response');
await cancelChatCompletion(localStorage.token, currentRequestId); await cancelOllamaRequest(localStorage.token, currentRequestId);
} }
currentRequestId = null; currentRequestId = null;
@ -511,7 +511,8 @@
if (messages.length == 2 && messages.at(1).content !== '') { if (messages.length == 2 && messages.at(1).content !== '') {
window.history.replaceState(history.state, '', `/c/${_chatId}`); window.history.replaceState(history.state, '', `/c/${_chatId}`);
await generateChatTitle(_chatId, userPrompt); const _title = await generateChatTitle(userPrompt);
await setChatTitle(_chatId, _title);
} }
}; };
@ -696,11 +697,8 @@
if (messages.length == 2) { if (messages.length == 2) {
window.history.replaceState(history.state, '', `/c/${_chatId}`); window.history.replaceState(history.state, '', `/c/${_chatId}`);
if ($settings?.titleAutoGenerateModel) { const _title = await generateChatTitle(userPrompt);
await generateChatTitle(_chatId, userPrompt); await setChatTitle(_chatId, _title);
} else {
await setChatTitle(_chatId, userPrompt);
}
} }
}; };
@ -754,23 +752,46 @@
} }
}; };
const generateChatTitle = async (_chatId, userPrompt) => { const generateChatTitle = async (userPrompt) => {
if ($settings.titleAutoGenerate ?? true) { if ($settings?.title?.auto ?? true) {
const model = $models.find((model) => model.id === selectedModels[0]);
const titleModelId =
model?.external ?? false
? $settings?.title?.modelExternal ?? selectedModels[0]
: $settings?.title?.model ?? selectedModels[0];
const titleModel = $models.find((model) => model.id === titleModelId);
console.log(titleModel);
const title = await generateTitle( const title = await generateTitle(
localStorage.token, localStorage.token,
$settings?.titleGenerationPrompt ?? $settings?.title?.prompt ??
$i18n.t( $i18n.t(
"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':" "Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
) + ' {{prompt}}', ) + ' {{prompt}}',
$settings?.titleAutoGenerateModel ?? selectedModels[0], titleModelId,
userPrompt userPrompt,
titleModel?.external ?? false
? titleModel.source === 'litellm'
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
); );
if (title) { return title;
await setChatTitle(_chatId, title);
}
} else { } else {
await setChatTitle(_chatId, `${userPrompt}`); return `${userPrompt}`;
}
};
const setChatTitle = async (_chatId, _title) => {
if (_chatId === $chatId) {
title = _title;
}
if ($settings.saveChatHistory ?? true) {
chat = await updateChatById(localStorage.token, _chatId, { title: _title });
await chats.set(await getChatList(localStorage.token));
} }
}; };
@ -801,17 +822,6 @@
_tags.set(await getAllChatTags(localStorage.token)); _tags.set(await getAllChatTags(localStorage.token));
}; };
const setChatTitle = async (_chatId, _title) => {
if (_chatId === $chatId) {
title = _title;
}
if ($settings.saveChatHistory ?? true) {
chat = await updateChatById(localStorage.token, _chatId, { title: _title });
await chats.set(await getChatList(localStorage.token));
}
};
</script> </script>
<svelte:head> <svelte:head>

View file

@ -19,7 +19,7 @@
} from '$lib/stores'; } from '$lib/stores';
import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils'; import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
import { generateChatCompletion, generateTitle, cancelChatCompletion } from '$lib/apis/ollama'; import { generateChatCompletion, cancelOllamaRequest } from '$lib/apis/ollama';
import { import {
addTagById, addTagById,
createNewChat, createNewChat,
@ -31,14 +31,19 @@
updateChatById updateChatById
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { queryCollection, queryDoc } from '$lib/apis/rag'; import { queryCollection, queryDoc } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte'; import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte'; import Messages from '$lib/components/chat/Messages.svelte';
import ModelSelector from '$lib/components/chat/ModelSelector.svelte'; import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
import Navbar from '$lib/components/layout/Navbar.svelte'; import Navbar from '$lib/components/layout/Navbar.svelte';
import { RAGTemplate } from '$lib/utils/rag'; import { RAGTemplate } from '$lib/utils/rag';
import { LITELLM_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import {
LITELLM_API_BASE_URL,
OPENAI_API_BASE_URL,
OLLAMA_API_BASE_URL,
WEBUI_BASE_URL
} from '$lib/constants';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -382,7 +387,7 @@
if (stopResponseFlag) { if (stopResponseFlag) {
controller.abort('User: Stop Response'); controller.abort('User: Stop Response');
await cancelChatCompletion(localStorage.token, currentRequestId); await cancelOllamaRequest(localStorage.token, currentRequestId);
} }
currentRequestId = null; currentRequestId = null;
@ -521,7 +526,8 @@
if (messages.length == 2 && messages.at(1).content !== '') { if (messages.length == 2 && messages.at(1).content !== '') {
window.history.replaceState(history.state, '', `/c/${_chatId}`); window.history.replaceState(history.state, '', `/c/${_chatId}`);
await generateChatTitle(_chatId, userPrompt); const _title = await generateChatTitle(userPrompt);
await setChatTitle(_chatId, _title);
} }
}; };
@ -706,11 +712,8 @@
if (messages.length == 2) { if (messages.length == 2) {
window.history.replaceState(history.state, '', `/c/${_chatId}`); window.history.replaceState(history.state, '', `/c/${_chatId}`);
if ($settings?.titleAutoGenerateModel) { const _title = await generateChatTitle(userPrompt);
await generateChatTitle(_chatId, userPrompt); await setChatTitle(_chatId, _title);
} else {
await setChatTitle(_chatId, userPrompt);
}
} }
}; };
@ -719,6 +722,19 @@
console.log('stopResponse'); console.log('stopResponse');
}; };
const regenerateResponse = async () => {
console.log('regenerateResponse');
if (messages.length != 0 && messages.at(-1).done == true) {
messages.splice(messages.length - 1, 1);
messages = messages;
let userMessage = messages.at(-1);
let userPrompt = userMessage.content;
await sendPrompt(userPrompt, userMessage.id);
}
};
const continueGeneration = async () => { const continueGeneration = async () => {
console.log('continueGeneration'); console.log('continueGeneration');
const _chatId = JSON.parse(JSON.stringify($chatId)); const _chatId = JSON.parse(JSON.stringify($chatId));
@ -751,36 +767,35 @@
} }
}; };
const regenerateResponse = async () => { const generateChatTitle = async (userPrompt) => {
console.log('regenerateResponse'); if ($settings?.title?.auto ?? true) {
if (messages.length != 0 && messages.at(-1).done == true) { const model = $models.find((model) => model.id === selectedModels[0]);
messages.splice(messages.length - 1, 1);
messages = messages;
let userMessage = messages.at(-1); const titleModelId =
let userPrompt = userMessage.content; model?.external ?? false
? $settings?.title?.modelExternal ?? selectedModels[0]
: $settings?.title?.model ?? selectedModels[0];
const titleModel = $models.find((model) => model.id === titleModelId);
await sendPrompt(userPrompt, userMessage.id); console.log(titleModel);
}
};
const generateChatTitle = async (_chatId, userPrompt) => {
if ($settings.titleAutoGenerate ?? true) {
const title = await generateTitle( const title = await generateTitle(
localStorage.token, localStorage.token,
$settings?.titleGenerationPrompt ?? $settings?.title?.prompt ??
$i18n.t( $i18n.t(
"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':" "Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
) + ' {{prompt}}', ) + ' {{prompt}}',
$settings?.titleAutoGenerateModel ?? selectedModels[0], titleModelId,
userPrompt userPrompt,
titleModel?.external ?? false
? titleModel.source === 'litellm'
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
); );
if (title) { return title;
await setChatTitle(_chatId, title);
}
} else { } else {
await setChatTitle(_chatId, `${userPrompt}`); return `${userPrompt}`;
} }
}; };
@ -789,8 +804,10 @@
title = _title; title = _title;
} }
chat = await updateChatById(localStorage.token, _chatId, { title: _title }); if ($settings.saveChatHistory ?? true) {
await chats.set(await getChatList(localStorage.token)); chat = await updateChatById(localStorage.token, _chatId, { title: _title });
await chats.set(await getChatList(localStorage.token));
}
}; };
const getTags = async () => { const getTags = async () => {
@ -843,7 +860,7 @@
shareEnabled={messages.length > 0} shareEnabled={messages.length > 0}
initNewChat={async () => { initNewChat={async () => {
if (currentRequestId !== null) { if (currentRequestId !== null) {
await cancelChatCompletion(localStorage.token, currentRequestId); await cancelOllamaRequest(localStorage.token, currentRequestId);
currentRequestId = null; currentRequestId = null;
} }

View file

@ -13,7 +13,7 @@
} from '$lib/constants'; } from '$lib/constants';
import { WEBUI_NAME, config, user, models, settings } from '$lib/stores'; import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama'; import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
@ -52,7 +52,7 @@
// const cancelHandler = async () => { // const cancelHandler = async () => {
// if (currentRequestId) { // if (currentRequestId) {
// const res = await cancelChatCompletion(localStorage.token, currentRequestId); // const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
// currentRequestId = null; // currentRequestId = null;
// loading = false; // loading = false;
// } // }
@ -95,7 +95,7 @@
const { value, done } = await reader.read(); const { value, done } = await reader.read();
if (done || stopResponseFlag) { if (done || stopResponseFlag) {
if (stopResponseFlag) { if (stopResponseFlag) {
await cancelChatCompletion(localStorage.token, currentRequestId); await cancelOllamaRequest(localStorage.token, currentRequestId);
} }
currentRequestId = null; currentRequestId = null;
@ -181,7 +181,7 @@
const { value, done } = await reader.read(); const { value, done } = await reader.read();
if (done || stopResponseFlag) { if (done || stopResponseFlag) {
if (stopResponseFlag) { if (stopResponseFlag) {
await cancelChatCompletion(localStorage.token, currentRequestId); await cancelOllamaRequest(localStorage.token, currentRequestId);
} }
currentRequestId = null; currentRequestId = null;