Merge branch 'main' into functions

This commit is contained in:
Timothy Jaeryang Baek 2024-02-23 04:51:01 -05:00 committed by GitHub
commit 837feb4e79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 2439 additions and 481 deletions

1
backend/.gitignore vendored
View file

@ -7,4 +7,5 @@ uploads
_test
Pipfile
data/*
!data/config.json
.webui_secret_key

View file

@ -56,7 +56,7 @@ def transcribe(
model = WhisperModel(
WHISPER_MODEL,
device="cpu",
device="auto",
compute_type="int8",
download_root=WHISPER_MODEL_DIR,
)

193
backend/apps/images/main.py Normal file
View file

@ -0,0 +1,193 @@
import re
import requests
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from config import AUTOMATIC1111_BASE_URL
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
app.state.IMAGE_SIZE = "512x512"
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
return app.state.ENABLED
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
try:
r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
app.state.ENABLED = not app.state.ENABLED
return app.state.ENABLED
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
class UrlUpdateForm(BaseModel):
url: str
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
if form_data.url == "":
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"status": True,
}
class ImageSizeUpdateForm(BaseModel):
size: str
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
@app.post("/size/update")
async def update_image_size(
form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size
return {
"IMAGE_SIZE": app.state.IMAGE_SIZE,
"status": True,
}
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
)
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
models = r.json()
return models
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
model: str
def set_model_handler(model: str):
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
return options
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
user=Depends(get_current_user),
):
return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
model: Optional[str] = None
prompt: str
n: int = 1
size: str = "512x512"
negative_prompt: Optional[str] = None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
print(form_data)
try:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
}
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
print(data)
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
return r.json()
except Exception as e:
print(e)
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))

View file

@ -1,6 +1,5 @@
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
@ -14,7 +13,8 @@ import os, shutil
from pathlib import Path
from typing import List
# from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import (
WebBaseLoader,
@ -30,16 +30,12 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import Chroma
from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
import json
import time
from apps.web.models.documents import (
@ -58,23 +54,37 @@ from utils.utils import get_current_user, get_admin_user
from config import (
UPLOAD_DIR,
DOCS_DIR,
EMBED_MODEL,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)
from constants import ERROR_MESSAGES
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
# model_name=EMBED_MODEL
# )
#
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
app = FastAPI()
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
origins = ["*"]
@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs]
try:
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@ -126,6 +139,38 @@ async def get_status():
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/embedding/model")
async def get_embedding_model(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
class EmbeddingModelUpdateForm(BaseModel):
embedding_model: str
@app.post("/embedding/model/update")
async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@ -190,8 +235,10 @@ def query_doc(
user=Depends(get_current_user),
):
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
@ -263,9 +310,12 @@ def query_collection(
for collection_name in form_data.collection_names:
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)

View file

@ -26,6 +26,8 @@ app = FastAPI()
origins = ["*"]
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.JWT_EXPIRES_IN = "-1"
app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
@ -55,7 +57,6 @@ app.include_router(utils.router, prefix="/utils", tags=["utils"])
async def get_status():
return {
"status": True,
"version": WEBUI_VERSION,
"auth": WEBUI_AUTH,
"default_models": app.state.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,

View file

@ -7,6 +7,7 @@ from fastapi import APIRouter, status
from pydantic import BaseModel
import time
import uuid
import re
from apps.web.models.auths import (
SigninForm,
@ -25,7 +26,7 @@ from utils.utils import (
get_admin_user,
create_token,
)
from utils.misc import get_gravatar_url, validate_email_format
from utils.misc import parse_duration, validate_email_format
from constants import ERROR_MESSAGES
router = APIRouter()
@ -95,10 +96,13 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse)
async def signin(form_data: SigninForm):
async def signin(request: Request, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user:
token = create_token(data={"id": user.id})
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
)
return {
"token": token,
@ -145,7 +149,10 @@ async def signup(request: Request, form_data: SignupForm):
)
if user:
token = create_token(data={"id": user.id})
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
)
# response.set_cookie(key='token', value=token, httponly=True)
return {
@ -200,3 +207,33 @@ async def update_default_user_role(
if form_data.role in ["pending", "user", "admin"]:
request.app.state.DEFAULT_USER_ROLE = form_data.role
return request.app.state.DEFAULT_USER_ROLE
############################
# JWT Expiration
############################
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
duration: str
@router.post("/token/expires/update")
async def update_token_expires_duration(
request: Request,
form_data: UpdateJWTExpiresDurationForm,
user=Depends(get_admin_user),
):
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.JWT_EXPIRES_IN = form_data.duration
return request.app.state.JWT_EXPIRES_IN
else:
return request.app.state.JWT_EXPIRES_IN

View file

@ -5,6 +5,10 @@ from secrets import token_bytes
from base64 import b64encode
from constants import ERROR_MESSAGES
from pathlib import Path
import json
import markdown
from bs4 import BeautifulSoup
try:
from dotenv import load_dotenv, find_dotenv
@ -21,6 +25,75 @@ except ImportError:
ENV = os.environ.get("ENV", "dev")
try:
with open(f"../package.json", "r") as f:
PACKAGE_DATA = json.load(f)
except:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
# Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
with open("../CHANGELOG.md", "r") as file:
changelog_content = file.read()
except:
changelog_content = ""
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
print(current)
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
# DATA/FRONTEND BUILD DIR
####################################
@ -28,6 +101,12 @@ ENV = os.environ.get("ENV", "dev")
DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve())
FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
try:
with open(f"{DATA_DIR}/config.json", "r") as f:
CONFIG_DATA = json.load(f)
except:
CONFIG_DATA = {}
####################################
# File Upload DIR
####################################
@ -87,9 +166,14 @@ if OPENAI_API_BASE_URL == "":
ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", True)
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
DEFAULT_PROMPT_SUGGESTIONS = os.environ.get(
"DEFAULT_PROMPT_SUGGESTIONS",
[
DEFAULT_PROMPT_SUGGESTIONS = (
CONFIG_DATA["ui"]["prompt_suggestions"]
if "ui" in CONFIG_DATA
and "prompt_suggestions" in CONFIG_DATA["ui"]
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
else [
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@ -106,8 +190,10 @@ DEFAULT_PROMPT_SUGGESTIONS = os.environ.get(
"title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
},
],
]
)
DEFAULT_USER_ROLE = "pending"
USER_PERMISSIONS = {"chat": {"deletion": True}}
@ -143,7 +229,12 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
EMBED_MODEL = "all-MiniLM-L6-v2"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
)
CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
@ -172,3 +263,10 @@ Query: [query]"""
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
####################################
# Images
####################################
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")

View file

@ -44,3 +44,6 @@ class ERROR_MESSAGES(str, Enum):
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
INCORRECT_FORMAT = (
lambda err="": f"Invalid format. Please use the correct format{err if err else ''}"
)

34
backend/data/config.json Normal file
View file

@ -0,0 +1,34 @@
{
"ui": {
"prompt_suggestions": [
{
"title": [
"Help me study",
"vocabulary for a college entrance exam"
],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
},
{
"title": [
"Give me ideas",
"for what to do with my kids' art"
],
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
},
{
"title": [
"Tell me a fun fact",
"about the Roman Empire"
],
"content": "Tell me a random fun fact about the Roman Empire"
},
{
"title": [
"Show me a code snippet",
"of a website's sticky header"
],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
}
]
}
}

View file

@ -1,5 +1,9 @@
from bs4 import BeautifulSoup
import json
import markdown
import time
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi import HTTPException
@ -12,11 +16,12 @@ from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.audio.main import app as audio_app
from apps.functions.main import app as functions_app
from apps.web.main import app as webui_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from config import ENV, FRONTEND_BUILD_DIR
from apps.web.main import app as webui_app
from config import ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
class SPAStaticFiles(StaticFiles):
@ -57,12 +62,30 @@ app.mount("/api/v1", webui_app)
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/functions/api/v1", functions_app)
@app.get("/api/config")
async def get_app_config():
return {
"status": True,
"version": VERSION,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
}
@app.get("/api/changelog")
async def get_app_changelog():
return CHANGELOG
app.mount(
"/",
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),

View file

@ -1,6 +1,8 @@
from pathlib import Path
import hashlib
import re
from datetime import timedelta
from typing import Optional
def get_gravatar_url(email):
@ -76,3 +78,34 @@ def extract_folders_after_data_docs(path):
tags.append("/".join(folders[: idx + 1]))
return tags
def parse_duration(duration: str) -> Optional[timedelta]:
if duration == "-1" or duration == "0":
return None
# Regular expression to find number and unit pairs
pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
matches = re.findall(pattern, duration)
if not matches:
raise ValueError("Invalid duration string")
total_duration = timedelta()
for number, _, unit in matches:
number = float(number)
if unit == "ms":
total_duration += timedelta(milliseconds=number)
elif unit == "s":
total_duration += timedelta(seconds=number)
elif unit == "m":
total_duration += timedelta(minutes=number)
elif unit == "h":
total_duration += timedelta(hours=number)
elif unit == "d":
total_duration += timedelta(days=number)
elif unit == "w":
total_duration += timedelta(weeks=number)
return total_duration