forked from open-webui/open-webui
backend support api key
This commit is contained in:
parent
ac294a74e7
commit
81e928030f
4 changed files with 100 additions and 8 deletions
|
@ -24,6 +24,7 @@ class Auth(Model):
|
|||
email = CharField()
|
||||
password = CharField()
|
||||
active = BooleanField()
|
||||
api_key = CharField(null=True, unique=True)
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
@ -34,6 +35,7 @@ class AuthModel(BaseModel):
|
|||
email: str
|
||||
password: str
|
||||
active: bool = True
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
####################
|
||||
|
@ -45,6 +47,8 @@ class Token(BaseModel):
|
|||
token: str
|
||||
token_type: str
|
||||
|
||||
class ApiKey(BaseModel):
|
||||
api_key: Optional[str] = None
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
|
@ -122,6 +126,21 @@ class AuthsTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_api_key: {api_key}")
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
return None
|
||||
try:
|
||||
auth = Auth.get(Auth.api_key == api_key, Auth.active == True)
|
||||
if auth:
|
||||
user = Users.get_user_by_id(auth.id)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
||||
try:
|
||||
query = Auth.update(password=new_password).where(Auth.id == id)
|
||||
|
@ -140,6 +159,22 @@ class AuthsTable:
|
|||
except:
|
||||
return False
|
||||
|
||||
def update_api_key_by_id(self, id: str, api_key: str) -> str:
|
||||
try:
|
||||
query = Auth.update(api_key=api_key).where(Auth.id == id)
|
||||
result = query.execute()
|
||||
|
||||
return True if result == 1 else False
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_api_key_by_id(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
auth = Auth.get(Auth.id == id)
|
||||
return auth.api_key
|
||||
except:
|
||||
return None
|
||||
|
||||
def delete_auth_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
# Delete User
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
from fastapi import Response, Request
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union
|
||||
from fastapi import Request
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import time
|
||||
import uuid
|
||||
import re
|
||||
|
||||
from apps.web.models.auths import (
|
||||
|
@ -17,6 +13,7 @@ from apps.web.models.auths import (
|
|||
UserResponse,
|
||||
SigninResponse,
|
||||
Auths,
|
||||
ApiKey
|
||||
)
|
||||
from apps.web.models.users import Users
|
||||
|
||||
|
@ -25,6 +22,7 @@ from utils.utils import (
|
|||
get_current_user,
|
||||
get_admin_user,
|
||||
create_token,
|
||||
create_api_key
|
||||
)
|
||||
from utils.misc import parse_duration, validate_email_format
|
||||
from utils.webhook import post_webhook
|
||||
|
@ -249,3 +247,40 @@ async def update_token_expires_duration(
|
|||
return request.app.state.JWT_EXPIRES_IN
|
||||
else:
|
||||
return request.app.state.JWT_EXPIRES_IN
|
||||
|
||||
|
||||
############################
|
||||
# API Key
|
||||
############################
|
||||
|
||||
|
||||
# create api key
|
||||
@router.post("/api_key", response_model=ApiKey)
|
||||
async def create_api_key_(user=Depends(get_current_user)):
|
||||
api_key = create_api_key()
|
||||
success = Auths.update_api_key_by_id(user.id, api_key)
|
||||
if success:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
|
||||
|
||||
|
||||
# delete api key
|
||||
@router.delete("/api_key", response_model=bool)
|
||||
async def delete_api_key(user=Depends(get_current_user)):
|
||||
success = Auths.update_api_key_by_id(user.id, None)
|
||||
return success
|
||||
|
||||
|
||||
# get api key
|
||||
@router.get("/api_key", response_model=ApiKey)
|
||||
async def get_api_key(user=Depends(get_current_user)):
|
||||
api_key = Auths.get_api_key_by_id(user.id, None)
|
||||
if api_key:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||
|
|
|
@ -58,5 +58,6 @@ class ERROR_MESSAGES(str, Enum):
|
|||
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
|
||||
|
||||
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
|
||||
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found"
|
||||
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
|
||||
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
||||
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from apps.web.models.users import Users
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Optional
|
||||
from constants import ERROR_MESSAGES
|
||||
|
@ -8,6 +9,7 @@ from passlib.context import CryptContext
|
|||
from datetime import datetime, timedelta
|
||||
import requests
|
||||
import jwt
|
||||
import uuid
|
||||
import logging
|
||||
import config
|
||||
|
||||
|
@ -58,6 +60,11 @@ def extract_token_from_auth_header(auth_header: str):
|
|||
return auth_header[len("Bearer ") :]
|
||||
|
||||
|
||||
def create_api_key():
|
||||
key = str(uuid.uuid4()).replace("-", "")
|
||||
return f"sk-{key}"
|
||||
|
||||
|
||||
def get_http_authorization_cred(auth_header: str):
|
||||
try:
|
||||
scheme, credentials = auth_header.split(" ")
|
||||
|
@ -69,6 +76,10 @@ def get_http_authorization_cred(auth_header: str):
|
|||
def get_current_user(
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
# auth by api key
|
||||
if auth_token.credentials.startswith("sk-"):
|
||||
return get_current_user_by_api_key(auth_token.credentials)
|
||||
# auth by jwt token
|
||||
data = decode_token(auth_token.credentials)
|
||||
if data != None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
|
@ -84,6 +95,16 @@ def get_current_user(
|
|||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
def get_current_user_by_api_key(api_key: str):
|
||||
from apps.web.models.auths import Auths
|
||||
|
||||
user = Auths.authenticate_user_by_api_key(api_key)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
return user
|
||||
|
||||
def get_verified_user(user=Depends(get_current_user)):
|
||||
if user.role not in {"user", "admin"}:
|
||||
|
|
Loading…
Reference in a new issue