Merge pull request #408 from ollama-webui/main

rag
This commit is contained in:
Timothy Jaeryang Baek 2024-01-06 17:33:52 -08:00 committed by GitHub
commit 7071716f54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
57 changed files with 2535 additions and 1500 deletions

View file

@ -1,119 +1,111 @@
from flask import Flask, request, Response, jsonify
from flask_cors import CORS
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__)
CORS(
app
) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define the target server URL
TARGET_SERVER_URL = OLLAMA_API_BASE_URL
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.route("/", defaults={"path": ""}, methods=["GET", "POST", "PUT", "DELETE"])
@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
def proxy(path):
# Combine the base URL of the target server with the requested path
target_url = f"{TARGET_SERVER_URL}/{path}"
print(target_url)
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# Get data from the original request
data = request.get_data()
class UrlUpdateForm(BaseModel):
url: str
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
body = await request.body()
headers = dict(request.headers)
# Basic RBAC support
if WEBUI_AUTH:
if "Authorization" in headers:
_, credentials = headers["Authorization"].split()
token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user:
# Only user and admin roles can access
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
# Only admin role can perform actions above
if user.role == "admin":
pass
else:
return (
jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
401,
)
else:
pass
else:
return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
pass
r = None
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None)
headers.pop("Authorization", None)
headers.pop("Origin", None)
headers.pop("Referer", None)
r = None
def get_request():
nonlocal r
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
# Make a request to the target server
r = requests.request(
method=request.method,
url=target_url,
data=data,
headers=headers,
stream=True, # Enable streaming for server-sent events
)
r.raise_for_status()
# Proxy the target server's response to the client
def generate():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
response = Response(generate(), status=r.status_code)
# Copy headers from the target server's response to the client's response
for key, value in r.headers.items():
response.headers[key] = value
return response
return await run_in_threadpool(get_request)
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if r != None:
print(r.text)
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
print(res)
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
return (
jsonify(
{
"detail": error_detail,
"message": str(e),
}
),
400,
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
if __name__ == "__main__":
app.run(debug=True)

View file

@ -0,0 +1,127 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
import aiohttp
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel):
url: str
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# async def fetch_sse(method, target_url, body, headers):
# async with aiohttp.ClientSession() as session:
# try:
# async with session.request(
# method, target_url, data=body, headers=headers
# ) as response:
# print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Ollama WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
print(target_url)
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None)
headers.pop("Authorization", None)
headers.pop("Origin", None)
headers.pop("Referer", None)
session = aiohttp.ClientSession()
response = None
try:
response = await session.request(
request.method, target_url, data=body, headers=headers
)
print(response)
if not response.ok:
data = await response.json()
print(data)
response.raise_for_status()
async def generate():
async for line in response.content:
print(line)
yield line
await session.close()
return StreamingResponse(generate(), response.status)
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if response is not None:
try:
res = await response.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
await session.close()
raise HTTPException(
status_code=response.status if response else 500,
detail=error_detail,
)

143
backend/apps/openai/main.py Normal file
View file

@ -0,0 +1,143 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY
class UrlUpdateForm(BaseModel):
url: str
class KeyUpdateForm(BaseModel):
key: str
@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm,
user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm,
user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY)
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401,
detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body()
# headers = dict(request.headers)
# print(headers)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )
response_data = r.json()
print(type(response_data))
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"],
response_data["data"]))
return response_data
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status_code, detail=error_detail)

View file

@ -22,10 +22,11 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -4,7 +4,6 @@ import time
import uuid
from peewee import *
from apps.web.models.users import UserModel, Users
from utils.utils import (
verify_password,
@ -123,6 +122,15 @@ class AuthsTable:
except:
return False
def update_email_by_id(self, id: str, email: str) -> bool:
try:
query = Auth.update(email=email).where(Auth.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def delete_auth_by_id(self, id: str) -> bool:
try:
# Delete User

View file

@ -3,14 +3,12 @@ from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
from apps.web.internal.db import DB
####################
# Chat DB Schema
####################
@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
def insert_new_chat(self, user_id: str,
form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": form_data.chat["title"]
if "title" in form_data.chat
else "New Chat",
"title": form_data.chat["title"] if "title" in
form_data.chat else "New Chat",
"chat": json.dumps(form_data.chat),
"timestamp": int(time.time()),
}
)
})
result = Chat.create(**chat.model_dump())
return chat if result else None
@ -111,27 +109,25 @@ class ChatTable:
except:
return None
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
def get_chat_lists_by_user_id(self,
user_id: str,
skip: int = 0,
limit: int = 50) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
Chat.user_id == user_id).order_by(Chat.timestamp.desc())
# .limit(limit)
# .offset(skip)
]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
Chat.user_id == user_id).order_by(Chat.timestamp.desc())
]
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
def get_chat_by_id_and_user_id(self, id: str,
user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat))
@ -146,7 +142,8 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query = Chat.delete().where((Chat.id == id)
& (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed.
return True

View file

@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
class ModelfilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
self, user_id: str,
form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
@ -72,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
}
)
})
try:
result = Modelfile.create(**modelfile.model_dump())
@ -87,28 +87,29 @@ class ModelfilesTable:
else:
return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
"modelfile":
json.loads(modelfile.modelfile),
}) for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),

View file

@ -47,13 +47,13 @@ class PromptForm(BaseModel):
class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
@ -61,8 +61,7 @@ class PromptsTable:
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
}
)
})
try:
result = Prompt.create(**prompt.model_dump())
@ -82,14 +81,13 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,

View file

@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
####################
# User DB Schema
####################
@ -45,6 +44,13 @@ class UserRoleUpdateForm(BaseModel):
role: str
class UserUpdateForm(BaseModel):
name: str
email: str
profile_image_url: str
password: Optional[str] = None
class UsersTable:
def __init__(self, db):
self.db = db
@ -102,6 +108,16 @@ class UsersTable:
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def delete_user_by_id(self, id: str) -> bool:
try:
# Delete User Chats

View file

@ -8,7 +8,6 @@ from pydantic import BaseModel
import time
import uuid
from apps.web.models.auths import (
SigninForm,
SignupForm,
@ -19,12 +18,10 @@ from apps.web.models.auths import (
)
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
router = APIRouter()
############################
@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)):
@router.post("/update/password", response_model=bool)
async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
async def update_password(form_data: UpdatePasswordForm,
session_user=Depends(get_current_user)):
if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password)
@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm):
try:
role = "admin" if Users.get_num_users() == 0 else "pending"
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(), hashed, form_data.name, role
)
user = Auths.insert_new_auth(form_data.email.lower(),
hashed, form_data.name, role)
if user:
token = create_token(data={"email": user.email})
@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm):
}
else:
raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
)
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
raise HTTPException(500,
detail=ERROR_MESSAGES.DEFAULT(err))
else:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
raise HTTPException(400,
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

View file

@ -17,8 +17,7 @@ from apps.web.models.chats import (
)
from utils.utils import (
bearer_scheme,
)
bearer_scheme, )
from constants import ERROR_MESSAGES
router = APIRouter()
@ -30,8 +29,7 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
user=Depends(get_current_user), skip: int = 0, limit: int = 50):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
@ -43,8 +41,9 @@ async def get_user_chats(
@router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
}) for chat in Chats.get_all_chats_by_user_id(user.id)
]
@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND)
############################
@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_current_user)
):
async def update_chat_by_id(id: str,
form_data: ChatForm,
user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -10,7 +10,6 @@ import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -28,9 +27,9 @@ class SetDefaultModelsForm(BaseModel):
@router.post("/default/models", response_model=str)
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
):
async def set_global_default_models(request: Request,
form_data: SetDefaultModelsForm,
user=Depends(get_current_user)):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS

View file

@ -24,7 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit)
@ -34,9 +36,8 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_curren
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(
form_data: ModelfileForm, user=Depends(get_current_user)
):
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -49,9 +50,9 @@ async def create_new_modelfile(
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -65,16 +66,17 @@ async def create_new_modelfile(
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -88,9 +90,8 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depend
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_current_user)
):
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -104,15 +105,14 @@ async def update_modelfile_by_tag_name(
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
form_data.tag_name, updated_modelfile)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -126,9 +126,8 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user
@ -30,7 +29,8 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
async def create_new_prompt(form_data: PromptForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -79,9 +79,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@router.post("/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user)
):
async def update_prompt_by_command(command: str,
form_data: PromptForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -104,7 +104,8 @@ async def update_prompt_by_command(
@router.delete("/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
async def delete_prompt_by_command(command: str,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -8,14 +8,12 @@ from pydantic import BaseModel
import time
import uuid
from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from utils.utils import get_current_user
from utils.utils import get_current_user, get_password_hash
from constants import ERROR_MESSAGES
router = APIRouter()
############################
@ -57,6 +55,62 @@ async def update_user_role(
)
############################
# UpdateUserById
############################
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id)
if user:
if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(form_data.email.lower())
if email_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.EMAIL_TAKEN,
)
if form_data.password:
hashed = get_password_hash(form_data.password)
print(hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id(
user_id,
{
"name": form_data.name,
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,
},
)
if updated_user:
return updated_user
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# DeleteUserById
############################

View file

@ -9,12 +9,10 @@ import os
import aiohttp
import json
from utils.misc import calculate_sha256
from config import OLLAMA_API_BASE_URL
router = APIRouter()
@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url):
return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
async def download_file_stream(url,
file_path,
file_name,
chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
total_size = int(response.headers.get("content-length",
0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
@router.get("/download")
async def download(
url: str,
):
async def download(url: str, ):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)):
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_write_stream(), media_type="text/event-stream")
return StreamingResponse(file_write_stream(),
media_type="text/event-stream")

View file

@ -19,19 +19,28 @@ ENV = os.environ.get("ENV", "dev")
# OLLAMA_API_BASE_URL
####################################
OLLAMA_API_BASE_URL = os.environ.get(
"OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
"http://localhost:11434/api")
if ENV == "prod":
if OLLAMA_API_BASE_URL == "/ollama/api":
OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
####################################
# OPENAI_API
####################################
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI_VERSION
####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.42")
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50")
####################################
# WEBUI_AUTH (Required for security)

View file

@ -6,6 +6,7 @@ class MESSAGES(str, Enum):
class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
@ -29,8 +30,8 @@ class ERROR_MESSAGES(str, Enum):
UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure."
)
"The requested action has been restricted as a security measure.")
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes."

View file

@ -6,12 +6,15 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.web.main import app as webui_app
import time
class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
try:
return await super().get_response(path, scope)
@ -46,5 +49,9 @@ async def check_url(request: Request, call_next):
app.mount("/api/v1", webui_app)
app.mount("/ollama/api", WSGIMiddleware(ollama_app))
app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files")
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/",
SPAStaticFiles(directory="../build", html=True),
name="spa-static-files")

View file

@ -8,9 +8,12 @@ from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
import jwt
import logging
import config
logging.getLogger("passlib").setLevel(logging.ERROR)
JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY
ALGORITHM = "HS256"