forked from open-webui/open-webui
backend support api key
This commit is contained in:
parent
ac294a74e7
commit
81e928030f
4 changed files with 100 additions and 8 deletions
|
@ -24,6 +24,7 @@ class Auth(Model):
|
||||||
email = CharField()
|
email = CharField()
|
||||||
password = CharField()
|
password = CharField()
|
||||||
active = BooleanField()
|
active = BooleanField()
|
||||||
|
api_key = CharField(null=True, unique=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = DB
|
database = DB
|
||||||
|
@ -34,6 +35,7 @@ class AuthModel(BaseModel):
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
active: bool = True
|
active: bool = True
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -45,6 +47,8 @@ class Token(BaseModel):
|
||||||
token: str
|
token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
class ApiKey(BaseModel):
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
@ -122,6 +126,21 @@ class AuthsTable:
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||||
|
log.info(f"authenticate_user_by_api_key: {api_key}")
|
||||||
|
# if no api_key, return None
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
auth = Auth.get(Auth.api_key == api_key, Auth.active == True)
|
||||||
|
if auth:
|
||||||
|
user = Users.get_user_by_id(auth.id)
|
||||||
|
return user
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
||||||
try:
|
try:
|
||||||
query = Auth.update(password=new_password).where(Auth.id == id)
|
query = Auth.update(password=new_password).where(Auth.id == id)
|
||||||
|
@ -140,6 +159,22 @@ class AuthsTable:
|
||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def update_api_key_by_id(self, id: str, api_key: str) -> str:
|
||||||
|
try:
|
||||||
|
query = Auth.update(api_key=api_key).where(Auth.id == id)
|
||||||
|
result = query.execute()
|
||||||
|
|
||||||
|
return True if result == 1 else False
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_api_key_by_id(self, id: str) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
auth = Auth.get(Auth.id == id)
|
||||||
|
return auth.api_key
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def delete_auth_by_id(self, id: str) -> bool:
|
def delete_auth_by_id(self, id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
# Delete User
|
# Delete User
|
||||||
|
|
|
@ -1,12 +1,8 @@
|
||||||
from fastapi import Response, Request
|
from fastapi import Request
|
||||||
from fastapi import Depends, FastAPI, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from apps.web.models.auths import (
|
from apps.web.models.auths import (
|
||||||
|
@ -17,6 +13,7 @@ from apps.web.models.auths import (
|
||||||
UserResponse,
|
UserResponse,
|
||||||
SigninResponse,
|
SigninResponse,
|
||||||
Auths,
|
Auths,
|
||||||
|
ApiKey
|
||||||
)
|
)
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
|
|
||||||
|
@ -25,6 +22,7 @@ from utils.utils import (
|
||||||
get_current_user,
|
get_current_user,
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
create_token,
|
create_token,
|
||||||
|
create_api_key
|
||||||
)
|
)
|
||||||
from utils.misc import parse_duration, validate_email_format
|
from utils.misc import parse_duration, validate_email_format
|
||||||
from utils.webhook import post_webhook
|
from utils.webhook import post_webhook
|
||||||
|
@ -249,3 +247,40 @@ async def update_token_expires_duration(
|
||||||
return request.app.state.JWT_EXPIRES_IN
|
return request.app.state.JWT_EXPIRES_IN
|
||||||
else:
|
else:
|
||||||
return request.app.state.JWT_EXPIRES_IN
|
return request.app.state.JWT_EXPIRES_IN
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# API Key
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
# create api key
|
||||||
|
@router.post("/api_key", response_model=ApiKey)
|
||||||
|
async def create_api_key_(user=Depends(get_current_user)):
|
||||||
|
api_key = create_api_key()
|
||||||
|
success = Auths.update_api_key_by_id(user.id, api_key)
|
||||||
|
if success:
|
||||||
|
return {
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
# delete api key
|
||||||
|
@router.delete("/api_key", response_model=bool)
|
||||||
|
async def delete_api_key(user=Depends(get_current_user)):
|
||||||
|
success = Auths.update_api_key_by_id(user.id, None)
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
# get api key
|
||||||
|
@router.get("/api_key", response_model=ApiKey)
|
||||||
|
async def get_api_key(user=Depends(get_current_user)):
|
||||||
|
api_key = Auths.get_api_key_by_id(user.id, None)
|
||||||
|
if api_key:
|
||||||
|
return {
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||||
|
|
|
@ -58,5 +58,6 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
|
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
|
||||||
|
|
||||||
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
|
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
|
||||||
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found"
|
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
|
||||||
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
||||||
|
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
from fastapi import HTTPException, status, Depends
|
from fastapi import HTTPException, status, Depends
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
@ -8,6 +9,7 @@ from passlib.context import CryptContext
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import requests
|
import requests
|
||||||
import jwt
|
import jwt
|
||||||
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
import config
|
import config
|
||||||
|
|
||||||
|
@ -58,6 +60,11 @@ def extract_token_from_auth_header(auth_header: str):
|
||||||
return auth_header[len("Bearer ") :]
|
return auth_header[len("Bearer ") :]
|
||||||
|
|
||||||
|
|
||||||
|
def create_api_key():
|
||||||
|
key = str(uuid.uuid4()).replace("-", "")
|
||||||
|
return f"sk-{key}"
|
||||||
|
|
||||||
|
|
||||||
def get_http_authorization_cred(auth_header: str):
|
def get_http_authorization_cred(auth_header: str):
|
||||||
try:
|
try:
|
||||||
scheme, credentials = auth_header.split(" ")
|
scheme, credentials = auth_header.split(" ")
|
||||||
|
@ -69,6 +76,10 @@ def get_http_authorization_cred(auth_header: str):
|
||||||
def get_current_user(
|
def get_current_user(
|
||||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||||
):
|
):
|
||||||
|
# auth by api key
|
||||||
|
if auth_token.credentials.startswith("sk-"):
|
||||||
|
return get_current_user_by_api_key(auth_token.credentials)
|
||||||
|
# auth by jwt token
|
||||||
data = decode_token(auth_token.credentials)
|
data = decode_token(auth_token.credentials)
|
||||||
if data != None and "id" in data:
|
if data != None and "id" in data:
|
||||||
user = Users.get_user_by_id(data["id"])
|
user = Users.get_user_by_id(data["id"])
|
||||||
|
@ -84,6 +95,16 @@ def get_current_user(
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_current_user_by_api_key(api_key: str):
|
||||||
|
from apps.web.models.auths import Auths
|
||||||
|
|
||||||
|
user = Auths.authenticate_user_by_api_key(api_key)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
def get_verified_user(user=Depends(get_current_user)):
|
def get_verified_user(user=Depends(get_current_user)):
|
||||||
if user.role not in {"user", "admin"}:
|
if user.role not in {"user", "admin"}:
|
||||||
|
|
Loading…
Reference in a new issue