backend support api key

This commit is contained in:
liu.vaayne 2024-03-26 18:22:17 +08:00 committed by Vaayne
parent ac294a74e7
commit 81e928030f
4 changed files with 100 additions and 8 deletions

View file

@ -24,6 +24,7 @@ class Auth(Model):
email = CharField() email = CharField()
password = CharField() password = CharField()
active = BooleanField() active = BooleanField()
api_key = CharField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
@ -34,6 +35,7 @@ class AuthModel(BaseModel):
email: str email: str
password: str password: str
active: bool = True active: bool = True
api_key: Optional[str] = None
#################### ####################
@ -45,6 +47,8 @@ class Token(BaseModel):
token: str token: str
token_type: str token_type: str
class ApiKey(BaseModel):
api_key: Optional[str] = None
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
@ -122,6 +126,21 @@ class AuthsTable:
except: except:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
auth = Auth.get(Auth.api_key == api_key, Auth.active == True)
if auth:
user = Users.get_user_by_id(auth.id)
return user
else:
return None
except:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try: try:
query = Auth.update(password=new_password).where(Auth.id == id) query = Auth.update(password=new_password).where(Auth.id == id)
@ -140,6 +159,22 @@ class AuthsTable:
except: except:
return False return False
def update_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = Auth.update(api_key=api_key).where(Auth.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def get_api_key_by_id(self, id: str) -> Optional[str]:
try:
auth = Auth.get(Auth.id == id)
return auth.api_key
except:
return None
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
try: try:
# Delete User # Delete User

View file

@ -1,12 +1,8 @@
from fastapi import Response, Request from fastapi import Request
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter, status from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import time
import uuid
import re import re
from apps.web.models.auths import ( from apps.web.models.auths import (
@ -17,6 +13,7 @@ from apps.web.models.auths import (
UserResponse, UserResponse,
SigninResponse, SigninResponse,
Auths, Auths,
ApiKey
) )
from apps.web.models.users import Users from apps.web.models.users import Users
@ -25,6 +22,7 @@ from utils.utils import (
get_current_user, get_current_user,
get_admin_user, get_admin_user,
create_token, create_token,
create_api_key
) )
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
@ -249,3 +247,40 @@ async def update_token_expires_duration(
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
else: else:
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
############################
# API Key
############################
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key()
success = Auths.update_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
success = Auths.update_api_key_by_id(user.id, None)
return success
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Auths.get_api_key_by_id(user.id, None)
if api_key:
return {
"api_key": api_key,
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

View file

@ -58,5 +58,6 @@ class ERROR_MESSAGES(str, Enum):
RATE_LIMIT_EXCEEDED = "API rate limit exceeded" RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."

View file

@ -1,6 +1,7 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -8,6 +9,7 @@ from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
import jwt import jwt
import uuid
import logging import logging
import config import config
@ -58,6 +60,11 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def create_api_key():
key = str(uuid.uuid4()).replace("-", "")
return f"sk-{key}"
def get_http_authorization_cred(auth_header: str): def get_http_authorization_cred(auth_header: str):
try: try:
scheme, credentials = auth_header.split(" ") scheme, credentials = auth_header.split(" ")
@ -69,6 +76,10 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
): ):
# auth by api key
if auth_token.credentials.startswith("sk-"):
return get_current_user_by_api_key(auth_token.credentials)
# auth by jwt token
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
@ -84,6 +95,16 @@ def get_current_user(
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
def get_current_user_by_api_key(api_key: str):
from apps.web.models.auths import Auths
user = Auths.authenticate_user_by_api_key(api_key)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return user
def get_verified_user(user=Depends(get_current_user)): def get_verified_user(user=Depends(get_current_user)):
if user.role not in {"user", "admin"}: if user.role not in {"user", "admin"}: