forked from open-webui/open-webui
		
	backend support api key
This commit is contained in:
		
							parent
							
								
									ac294a74e7
								
							
						
					
					
						commit
						81e928030f
					
				
					 4 changed files with 100 additions and 8 deletions
				
			
		|  | @ -24,6 +24,7 @@ class Auth(Model): | ||||||
|     email = CharField() |     email = CharField() | ||||||
|     password = CharField() |     password = CharField() | ||||||
|     active = BooleanField() |     active = BooleanField() | ||||||
|  |     api_key = CharField(null=True, unique=True) | ||||||
| 
 | 
 | ||||||
|     class Meta: |     class Meta: | ||||||
|         database = DB |         database = DB | ||||||
|  | @ -34,6 +35,7 @@ class AuthModel(BaseModel): | ||||||
|     email: str |     email: str | ||||||
|     password: str |     password: str | ||||||
|     active: bool = True |     active: bool = True | ||||||
|  |     api_key: Optional[str] = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| #################### | #################### | ||||||
|  | @ -45,6 +47,8 @@ class Token(BaseModel): | ||||||
|     token: str |     token: str | ||||||
|     token_type: str |     token_type: str | ||||||
| 
 | 
 | ||||||
|  | class ApiKey(BaseModel): | ||||||
|  |     api_key: Optional[str] = None | ||||||
| 
 | 
 | ||||||
| class UserResponse(BaseModel): | class UserResponse(BaseModel): | ||||||
|     id: str |     id: str | ||||||
|  | @ -122,6 +126,21 @@ class AuthsTable: | ||||||
|         except: |         except: | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|  |     def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: | ||||||
|  |         log.info(f"authenticate_user_by_api_key: {api_key}") | ||||||
|  |         # if no api_key, return None | ||||||
|  |         if not api_key: | ||||||
|  |             return None | ||||||
|  |         try: | ||||||
|  |             auth = Auth.get(Auth.api_key == api_key, Auth.active == True) | ||||||
|  |             if auth: | ||||||
|  |                 user = Users.get_user_by_id(auth.id) | ||||||
|  |                 return user | ||||||
|  |             else: | ||||||
|  |                 return None | ||||||
|  |         except: | ||||||
|  |             return None | ||||||
|  | 
 | ||||||
|     def update_user_password_by_id(self, id: str, new_password: str) -> bool: |     def update_user_password_by_id(self, id: str, new_password: str) -> bool: | ||||||
|         try: |         try: | ||||||
|             query = Auth.update(password=new_password).where(Auth.id == id) |             query = Auth.update(password=new_password).where(Auth.id == id) | ||||||
|  | @ -140,6 +159,22 @@ class AuthsTable: | ||||||
|         except: |         except: | ||||||
|             return False |             return False | ||||||
| 
 | 
 | ||||||
|  |     def update_api_key_by_id(self, id: str, api_key: str) -> str: | ||||||
|  |         try: | ||||||
|  |             query = Auth.update(api_key=api_key).where(Auth.id == id) | ||||||
|  |             result = query.execute() | ||||||
|  | 
 | ||||||
|  |             return True if result == 1 else False | ||||||
|  |         except: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |     def get_api_key_by_id(self, id: str) -> Optional[str]: | ||||||
|  |         try: | ||||||
|  |             auth = Auth.get(Auth.id == id) | ||||||
|  |             return auth.api_key | ||||||
|  |         except: | ||||||
|  |             return None | ||||||
|  | 
 | ||||||
|     def delete_auth_by_id(self, id: str) -> bool: |     def delete_auth_by_id(self, id: str) -> bool: | ||||||
|         try: |         try: | ||||||
|             # Delete User |             # Delete User | ||||||
|  |  | ||||||
|  | @ -1,12 +1,8 @@ | ||||||
| from fastapi import Response, Request | from fastapi import Request | ||||||
| from fastapi import Depends, FastAPI, HTTPException, status | from fastapi import Depends, HTTPException, status | ||||||
| from datetime import datetime, timedelta |  | ||||||
| from typing import List, Union |  | ||||||
| 
 | 
 | ||||||
| from fastapi import APIRouter, status | from fastapi import APIRouter | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
| import time |  | ||||||
| import uuid |  | ||||||
| import re | import re | ||||||
| 
 | 
 | ||||||
| from apps.web.models.auths import ( | from apps.web.models.auths import ( | ||||||
|  | @ -17,6 +13,7 @@ from apps.web.models.auths import ( | ||||||
|     UserResponse, |     UserResponse, | ||||||
|     SigninResponse, |     SigninResponse, | ||||||
|     Auths, |     Auths, | ||||||
|  |     ApiKey | ||||||
| ) | ) | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
| 
 | 
 | ||||||
|  | @ -25,6 +22,7 @@ from utils.utils import ( | ||||||
|     get_current_user, |     get_current_user, | ||||||
|     get_admin_user, |     get_admin_user, | ||||||
|     create_token, |     create_token, | ||||||
|  |     create_api_key | ||||||
| ) | ) | ||||||
| from utils.misc import parse_duration, validate_email_format | from utils.misc import parse_duration, validate_email_format | ||||||
| from utils.webhook import post_webhook | from utils.webhook import post_webhook | ||||||
|  | @ -249,3 +247,40 @@ async def update_token_expires_duration( | ||||||
|         return request.app.state.JWT_EXPIRES_IN |         return request.app.state.JWT_EXPIRES_IN | ||||||
|     else: |     else: | ||||||
|         return request.app.state.JWT_EXPIRES_IN |         return request.app.state.JWT_EXPIRES_IN | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ############################ | ||||||
|  | # API Key | ||||||
|  | ############################ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # create api key | ||||||
|  | @router.post("/api_key", response_model=ApiKey) | ||||||
|  | async def create_api_key_(user=Depends(get_current_user)): | ||||||
|  |     api_key = create_api_key() | ||||||
|  |     success = Auths.update_api_key_by_id(user.id, api_key) | ||||||
|  |     if success: | ||||||
|  |         return { | ||||||
|  |             "api_key": api_key, | ||||||
|  |         } | ||||||
|  |     else: | ||||||
|  |         raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # delete api key | ||||||
|  | @router.delete("/api_key", response_model=bool) | ||||||
|  | async def delete_api_key(user=Depends(get_current_user)): | ||||||
|  |     success = Auths.update_api_key_by_id(user.id, None) | ||||||
|  |     return success | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # get api key | ||||||
|  | @router.get("/api_key", response_model=ApiKey) | ||||||
|  | async def get_api_key(user=Depends(get_current_user)): | ||||||
|  |     api_key = Auths.get_api_key_by_id(user.id, None) | ||||||
|  |     if api_key: | ||||||
|  |         return { | ||||||
|  |             "api_key": api_key, | ||||||
|  |         } | ||||||
|  |     else: | ||||||
|  |         raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) | ||||||
|  |  | ||||||
|  | @ -58,5 +58,6 @@ class ERROR_MESSAGES(str, Enum): | ||||||
|     RATE_LIMIT_EXCEEDED = "API rate limit exceeded" |     RATE_LIMIT_EXCEEDED = "API rate limit exceeded" | ||||||
| 
 | 
 | ||||||
|     MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" |     MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" | ||||||
|     OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" |     OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found" | ||||||
|     OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" |     OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" | ||||||
|  |     CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | ||||||
| from fastapi import HTTPException, status, Depends | from fastapi import HTTPException, status, Depends | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
|  | 
 | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
| from typing import Union, Optional | from typing import Union, Optional | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
|  | @ -8,6 +9,7 @@ from passlib.context import CryptContext | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
| import requests | import requests | ||||||
| import jwt | import jwt | ||||||
|  | import uuid | ||||||
| import logging | import logging | ||||||
| import config | import config | ||||||
| 
 | 
 | ||||||
|  | @ -58,6 +60,11 @@ def extract_token_from_auth_header(auth_header: str): | ||||||
|     return auth_header[len("Bearer ") :] |     return auth_header[len("Bearer ") :] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def create_api_key(): | ||||||
|  |     key = str(uuid.uuid4()).replace("-", "") | ||||||
|  |     return f"sk-{key}" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def get_http_authorization_cred(auth_header: str): | def get_http_authorization_cred(auth_header: str): | ||||||
|     try: |     try: | ||||||
|         scheme, credentials = auth_header.split(" ") |         scheme, credentials = auth_header.split(" ") | ||||||
|  | @ -69,6 +76,10 @@ def get_http_authorization_cred(auth_header: str): | ||||||
| def get_current_user( | def get_current_user( | ||||||
|     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), |     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), | ||||||
| ): | ): | ||||||
|  |     # auth by api key | ||||||
|  |     if auth_token.credentials.startswith("sk-"): | ||||||
|  |         return get_current_user_by_api_key(auth_token.credentials) | ||||||
|  |     # auth by jwt token | ||||||
|     data = decode_token(auth_token.credentials) |     data = decode_token(auth_token.credentials) | ||||||
|     if data != None and "id" in data: |     if data != None and "id" in data: | ||||||
|         user = Users.get_user_by_id(data["id"]) |         user = Users.get_user_by_id(data["id"]) | ||||||
|  | @ -84,6 +95,16 @@ def get_current_user( | ||||||
|             detail=ERROR_MESSAGES.UNAUTHORIZED, |             detail=ERROR_MESSAGES.UNAUTHORIZED, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | def get_current_user_by_api_key(api_key: str): | ||||||
|  |     from apps.web.models.auths import Auths | ||||||
|  | 
 | ||||||
|  |     user = Auths.authenticate_user_by_api_key(api_key) | ||||||
|  |     if user is None: | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|  |             detail=ERROR_MESSAGES.INVALID_TOKEN, | ||||||
|  |         ) | ||||||
|  |     return user | ||||||
| 
 | 
 | ||||||
| def get_verified_user(user=Depends(get_current_user)): | def get_verified_user(user=Depends(get_current_user)): | ||||||
|     if user.role not in {"user", "admin"}: |     if user.role not in {"user", "admin"}: | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 liu.vaayne
						liu.vaayne