forked from open-webui/open-webui
Merge branch 'dev' into embedding-model-fix-and-manual-update
This commit is contained in:
commit
506a061387
60 changed files with 1906 additions and 520 deletions
|
@ -28,6 +28,7 @@ from config import (
|
|||
UPLOAD_DIR,
|
||||
WHISPER_MODEL,
|
||||
WHISPER_MODEL_DIR,
|
||||
DEVICE_TYPE,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -42,6 +43,10 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# setting device type for whisper model
|
||||
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
|
||||
log.info(f"whisper_device_type: {whisper_device_type}")
|
||||
|
||||
|
||||
@app.post("/transcribe")
|
||||
def transcribe(
|
||||
|
@ -66,7 +71,7 @@ def transcribe(
|
|||
|
||||
model = WhisperModel(
|
||||
WHISPER_MODEL,
|
||||
device="auto",
|
||||
device=whisper_device_type,
|
||||
compute_type="int8",
|
||||
download_root=WHISPER_MODEL_DIR,
|
||||
)
|
||||
|
|
|
@ -215,7 +215,8 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|||
|
||||
if len(responses) > 0:
|
||||
lowest_version = min(
|
||||
responses, key=lambda x: tuple(map(int, x["version"].split(".")))
|
||||
responses,
|
||||
key=lambda x: tuple(map(int, x["version"].split("-")[0].split("."))),
|
||||
)
|
||||
|
||||
return {"version": lowest_version["version"]}
|
||||
|
|
|
@ -58,8 +58,8 @@ from config import (
|
|||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
DEVICE_TYPE,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
|
@ -86,7 +86,7 @@ app.state.TOP_K = 4
|
|||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -154,7 +154,7 @@ async def update_embedding_model(
|
|||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -471,25 +471,11 @@ def store_doc(
|
|||
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
is_valid_filename = True
|
||||
unsanitized_filename = file.filename
|
||||
if re.search(r'[\\/:"\*\?<>|\n\t ]', unsanitized_filename) is not None:
|
||||
is_valid_filename = False
|
||||
filename = os.path.basename(unsanitized_filename)
|
||||
|
||||
unvalidated_file_path = f"{UPLOAD_DIR}/{unsanitized_filename}"
|
||||
dereferenced_file_path = str(Path(unvalidated_file_path).resolve(strict=False))
|
||||
if not dereferenced_file_path.startswith(UPLOAD_DIR):
|
||||
is_valid_filename = False
|
||||
file_path = f"{UPLOAD_DIR}/{filename}"
|
||||
|
||||
if is_valid_filename:
|
||||
file_path = dereferenced_file_path
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
filename = file.filename
|
||||
contents = file.file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
@ -500,7 +486,7 @@ def store_doc(
|
|||
collection_name = calculate_sha256(f)[:63]
|
||||
f.close()
|
||||
|
||||
loader, known_type = get_loader(file.filename, file.content_type, file_path)
|
||||
loader, known_type = get_loader(filename, file.content_type, file_path)
|
||||
data = loader.load()
|
||||
|
||||
try:
|
||||
|
|
|
@ -86,6 +86,7 @@ class SignupForm(BaseModel):
|
|||
name: str
|
||||
email: str
|
||||
password: str
|
||||
profile_image_url: Optional[str] = "/user.png"
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
|
@ -94,7 +95,12 @@ class AuthsTable:
|
|||
self.db.create_tables([Auth])
|
||||
|
||||
def insert_new_auth(
|
||||
self, email: str, password: str, name: str, role: str = "pending"
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
name: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
) -> Optional[UserModel]:
|
||||
log.info("insert_new_auth")
|
||||
|
||||
|
@ -105,7 +111,7 @@ class AuthsTable:
|
|||
)
|
||||
result = Auth.create(**auth.model_dump())
|
||||
|
||||
user = Users.insert_new_user(id, name, email, role)
|
||||
user = Users.insert_new_user(id, name, email, profile_image_url, role)
|
||||
|
||||
if result and user:
|
||||
return user
|
||||
|
|
|
@ -206,6 +206,18 @@ class ChatTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
chat = Chat.get(Chat.share_id == id)
|
||||
|
||||
if chat:
|
||||
chat = Chat.get(Chat.id == id)
|
||||
return ChatModel(**model_to_dict(chat))
|
||||
else:
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
|
||||
|
|
|
@ -31,7 +31,7 @@ class UserModel(BaseModel):
|
|||
name: str
|
||||
email: str
|
||||
role: str = "pending"
|
||||
profile_image_url: str = "/user.png"
|
||||
profile_image_url: str
|
||||
timestamp: int # timestamp in epoch
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
@ -59,7 +59,12 @@ class UsersTable:
|
|||
self.db.create_tables([User])
|
||||
|
||||
def insert_new_user(
|
||||
self, id: str, name: str, email: str, role: str = "pending"
|
||||
self,
|
||||
id: str,
|
||||
name: str,
|
||||
email: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
) -> Optional[UserModel]:
|
||||
user = UserModel(
|
||||
**{
|
||||
|
@ -67,7 +72,7 @@ class UsersTable:
|
|||
"name": name,
|
||||
"email": email,
|
||||
"role": role,
|
||||
"profile_image_url": "/user.png",
|
||||
"profile_image_url": profile_image_url,
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
|
|
@ -163,7 +163,11 @@ async def signup(request: Request, form_data: SignupForm):
|
|||
)
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(), hashed, form_data.name, role
|
||||
form_data.email.lower(),
|
||||
hashed,
|
||||
form_data.name,
|
||||
form_data.profile_image_url,
|
||||
role,
|
||||
)
|
||||
|
||||
if user:
|
||||
|
|
|
@ -251,7 +251,15 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
|||
|
||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id(share_id)
|
||||
if user.role == "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role == "user":
|
||||
chat = Chats.get_chat_by_share_id(share_id)
|
||||
elif user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(share_id)
|
||||
|
||||
if chat:
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
|
|
|
@ -1,16 +1,11 @@
|
|||
from fastapi import APIRouter, UploadFile, File, BackgroundTasks
|
||||
from fastapi import APIRouter, UploadFile, File, Response
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from starlette.responses import StreamingResponse, FileResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from fpdf import FPDF
|
||||
import markdown
|
||||
import requests
|
||||
import os
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
from utils.utils import get_admin_user
|
||||
|
@ -18,7 +13,7 @@ from utils.misc import calculate_sha256, get_gravatar_url
|
|||
|
||||
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from typing import List
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -41,6 +36,59 @@ async def get_html_from_markdown(
|
|||
return {"html": markdown.markdown(form_data.md)}
|
||||
|
||||
|
||||
class ChatForm(BaseModel):
|
||||
title: str
|
||||
messages: List[dict]
|
||||
|
||||
|
||||
@router.post("/pdf")
|
||||
async def download_chat_as_pdf(
|
||||
form_data: ChatForm,
|
||||
):
|
||||
pdf = FPDF()
|
||||
pdf.add_page()
|
||||
|
||||
STATIC_DIR = "./static"
|
||||
FONTS_DIR = f"{STATIC_DIR}/fonts"
|
||||
|
||||
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
|
||||
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
|
||||
pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
|
||||
pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
|
||||
pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
|
||||
|
||||
pdf.set_font("NotoSans", size=12)
|
||||
pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP"])
|
||||
|
||||
pdf.set_auto_page_break(auto=True, margin=15)
|
||||
|
||||
# Adjust the effective page width for multi_cell
|
||||
effective_page_width = (
|
||||
pdf.w - 2 * pdf.l_margin - 10
|
||||
) # Subtracted an additional 10 for extra padding
|
||||
|
||||
# Add chat messages
|
||||
for message in form_data.messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
pdf.set_font("NotoSans", "B", size=14) # Bold for the role
|
||||
pdf.multi_cell(effective_page_width, 10, f"{role.upper()}", 0, "L")
|
||||
pdf.ln(1) # Extra space between messages
|
||||
|
||||
pdf.set_font("NotoSans", size=10) # Regular for content
|
||||
pdf.multi_cell(effective_page_width, 6, content, 0, "L")
|
||||
pdf.ln(1.5) # Extra space between messages
|
||||
|
||||
# Save the pdf with name .pdf
|
||||
pdf_bytes = pdf.output()
|
||||
|
||||
return Response(
|
||||
content=bytes(pdf_bytes),
|
||||
media_type="application/pdf",
|
||||
headers={"Content-Disposition": f"attachment;filename=chat.pdf"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/db/download")
|
||||
async def download_db(user=Depends(get_admin_user)):
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue