Merge pull request #1306 from vaayne/feature/support_auth_by_api_key

Feature/support auth by api key
This commit is contained in:
Timothy Jaeryang Baek 2024-04-02 10:09:27 -07:00 committed by GitHub
commit fa61e738c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 469 additions and 71 deletions

View file

@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "api_key")

View file

@ -47,6 +47,10 @@ class Token(BaseModel):
token_type: str token_type: str
class ApiKey(BaseModel):
api_key: Optional[str] = None
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
email: str email: str
@ -123,6 +127,18 @@ class AuthsTable:
except: except:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
user = Users.get_user_by_api_key(api_key)
return user if user else None
except:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}") log.info(f"authenticate_user_by_trusted_header: {email}")
try: try:

View file

@ -20,6 +20,7 @@ class User(Model):
role = CharField() role = CharField()
profile_image_url = CharField() profile_image_url = CharField()
timestamp = DateField() timestamp = DateField()
api_key = CharField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
@ -32,6 +33,7 @@ class UserModel(BaseModel):
role: str = "pending" role: str = "pending"
profile_image_url: str = "/user.png" profile_image_url: str = "/user.png"
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
api_key: Optional[str] = None
#################### ####################
@ -82,6 +84,13 @@ class UsersTable:
except: except:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = User.get(User.email == email) user = User.get(User.email == email)
@ -149,5 +158,21 @@ class UsersTable:
except: except:
return False return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
return user.api_key
except:
return None
Users = UsersTable(DB) Users = UsersTable(DB)

View file

@ -1,13 +1,10 @@
from fastapi import Response, Request from fastapi import Request
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter, status from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import time
import uuid
import re import re
import uuid
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
@ -17,6 +14,7 @@ from apps.web.models.auths import (
UserResponse, UserResponse,
SigninResponse, SigninResponse,
Auths, Auths,
ApiKey,
) )
from apps.web.models.users import Users from apps.web.models.users import Users
@ -25,6 +23,7 @@ from utils.utils import (
get_current_user, get_current_user,
get_admin_user, get_admin_user,
create_token, create_token,
create_api_key,
) )
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
@ -267,3 +266,40 @@ async def update_token_expires_duration(
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
else: else:
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
############################
# API Key
############################
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None)
return success
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id)
if api_key:
return {
"api_key": api_key,
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

View file

@ -60,7 +60,8 @@ class ERROR_MESSAGES(str, Enum):
RATE_LIMIT_EXCEEDED = "API rate limit exceeded" RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."

View file

@ -1,6 +1,8 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -8,6 +10,7 @@ from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
import jwt import jwt
import uuid
import logging import logging
import config import config
@ -58,6 +61,11 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def create_api_key():
key = str(uuid.uuid4()).replace("-", "")
return f"sk-{key}"
def get_http_authorization_cred(auth_header: str): def get_http_authorization_cred(auth_header: str):
try: try:
scheme, credentials = auth_header.split(" ") scheme, credentials = auth_header.split(" ")
@ -69,6 +77,10 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
): ):
# auth by api key
if auth_token.credentials.startswith("sk-"):
return get_current_user_by_api_key(auth_token.credentials)
# auth by jwt token
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
@ -85,6 +97,16 @@ def get_current_user(
) )
def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return user
def get_verified_user(user=Depends(get_current_user)): def get_verified_user(user=Depends(get_current_user)):
if user.role not in {"user", "admin"}: if user.role not in {"user", "admin"}:
raise HTTPException( raise HTTPException(

View file

@ -318,3 +318,78 @@ export const updateJWTExpiresDuration = async (token: string, duration: string)
return res; return res;
}; };
export const createAPIKey = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res.api_key;
};
export const getAPIKey = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res.api_key;
};
export const deleteAPIKey = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -3,11 +3,13 @@
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { user } from '$lib/stores'; import { user } from '$lib/stores';
import { updateUserProfile } from '$lib/apis/auths'; import { updateUserProfile, createAPIKey, getAPIKey } from '$lib/apis/auths';
import UpdatePassword from './Account/UpdatePassword.svelte'; import UpdatePassword from './Account/UpdatePassword.svelte';
import { getGravatarUrl } from '$lib/apis/utils'; import { getGravatarUrl } from '$lib/apis/utils';
import { copyToClipboard } from '$lib/utils'; import { copyToClipboard } from '$lib/utils';
import Plus from '$lib/components/icons/Plus.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -15,8 +17,14 @@
let profileImageUrl = ''; let profileImageUrl = '';
let name = ''; let name = '';
let showJWTToken = false; let showJWTToken = false;
let JWTTokenCopied = false; let JWTTokenCopied = false;
let APIKey = '';
let showAPIKey = false;
let APIKeyCopied = false;
let profileImageInputElement: HTMLInputElement; let profileImageInputElement: HTMLInputElement;
const submitHandler = async () => { const submitHandler = async () => {
@ -33,9 +41,23 @@
return false; return false;
}; };
onMount(() => { const createAPIKeyHandler = async () => {
APIKey = await createAPIKey(localStorage.token);
if (APIKey) {
toast.success($i18n.t('API Key created.'));
} else {
toast.error($i18n.t('Failed to create API Key.'));
}
};
onMount(async () => {
name = $user.name; name = $user.name;
profileImageUrl = $user.profile_image_url; profileImageUrl = $user.profile_image_url;
APIKey = await getAPIKey(localStorage.token).catch((error) => {
console.log(error);
return '';
});
}); });
</script> </script>
@ -170,22 +192,23 @@
<hr class=" dark:border-gray-700 my-4" /> <hr class=" dark:border-gray-700 my-4" />
<div class=" w-full justify-between"> <div class="flex flex-col gap-4">
<div class="flex w-full justify-between"> <div class="justify-between w-full">
<div class="flex justify-between w-full">
<div class="self-center text-xs font-medium">{$i18n.t('JWT Token')}</div> <div class="self-center text-xs font-medium">{$i18n.t('JWT Token')}</div>
</div> </div>
<div class="flex mt-2"> <div class="flex mt-2">
<div class="flex w-full"> <div class="flex w-full">
<input <input
class="w-full rounded-l-lg py-1.5 pl-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" class="w-full rounded-l-lg py-1.5 pl-4 text-sm bg-white dark:text-gray-300 dark:bg-gray-800 outline-none"
type={showJWTToken ? 'text' : 'password'} type={showJWTToken ? 'text' : 'password'}
value={localStorage.token} value={localStorage.token}
disabled disabled
/> />
<button <button
class="dark:bg-gray-800 px-2 transition rounded-r-lg" class="px-2 transition rounded-r-lg bg-white dark:bg-gray-800"
on:click={() => { on:click={() => {
showJWTToken = !showJWTToken; showJWTToken = !showJWTToken;
}} }}
@ -225,7 +248,7 @@
</div> </div>
<button <button
class="ml-1.5 px-1.5 py-1 hover:bg-gray-800 transition rounded-lg" class="ml-1.5 px-1.5 py-1 dark:hover:bg-gray-800 transition rounded-lg"
on:click={() => { on:click={() => {
copyToClipboard(localStorage.token); copyToClipboard(localStorage.token);
JWTTokenCopied = true; JWTTokenCopied = true;
@ -269,6 +292,143 @@
</button> </button>
</div> </div>
</div> </div>
<div class="justify-between w-full">
<div class="flex justify-between w-full">
<div class="self-center text-xs font-medium">{$i18n.t('API Key')}</div>
</div>
<div class="flex mt-2">
{#if APIKey}
<div class="flex w-full">
<input
class="w-full rounded-l-lg py-1.5 pl-4 text-sm bg-white dark:text-gray-300 dark:bg-gray-800 outline-none"
type={showAPIKey ? 'text' : 'password'}
value={APIKey}
disabled
/>
<button
class="px-2 transition rounded-r-lg bg-white dark:bg-gray-800"
on:click={() => {
showAPIKey = !showAPIKey;
}}
>
{#if showAPIKey}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M3.28 2.22a.75.75 0 0 0-1.06 1.06l10.5 10.5a.75.75 0 1 0 1.06-1.06l-1.322-1.323a7.012 7.012 0 0 0 2.16-3.11.87.87 0 0 0 0-.567A7.003 7.003 0 0 0 4.82 3.76l-1.54-1.54Zm3.196 3.195 1.135 1.136A1.502 1.502 0 0 1 9.45 8.389l1.136 1.135a3 3 0 0 0-4.109-4.109Z"
clip-rule="evenodd"
/>
<path
d="m7.812 10.994 1.816 1.816A7.003 7.003 0 0 1 1.38 8.28a.87.87 0 0 1 0-.566 6.985 6.985 0 0 1 1.113-2.039l2.513 2.513a3 3 0 0 0 2.806 2.806Z"
/>
</svg>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path d="M8 9.5a1.5 1.5 0 1 0 0-3 1.5 1.5 0 0 0 0 3Z" />
<path
fill-rule="evenodd"
d="M1.38 8.28a.87.87 0 0 1 0-.566 7.003 7.003 0 0 1 13.238.006.87.87 0 0 1 0 .566A7.003 7.003 0 0 1 1.379 8.28ZM11 8a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z"
clip-rule="evenodd"
/>
</svg>
{/if}
</button>
</div>
<button
class="ml-1.5 px-1.5 py-1 dark:hover:bg-gray-800 transition rounded-lg"
on:click={() => {
copyToClipboard(APIKey);
APIKeyCopied = true;
setTimeout(() => {
APIKeyCopied = false;
}, 2000);
}}
>
{#if APIKeyCopied}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M16.704 4.153a.75.75 0 01.143 1.052l-8 10.5a.75.75 0 01-1.127.075l-4.5-4.5a.75.75 0 011.06-1.06l3.894 3.893 7.48-9.817a.75.75 0 011.05-.143z"
clip-rule="evenodd"
/>
</svg>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M11.986 3H12a2 2 0 0 1 2 2v6a2 2 0 0 1-1.5 1.937V7A2.5 2.5 0 0 0 10 4.5H4.063A2 2 0 0 1 6 3h.014A2.25 2.25 0 0 1 8.25 1h1.5a2.25 2.25 0 0 1 2.236 2ZM10.5 4v-.75a.75.75 0 0 0-.75-.75h-1.5a.75.75 0 0 0-.75.75V4h3Z"
clip-rule="evenodd"
/>
<path
fill-rule="evenodd"
d="M3 6a1 1 0 0 0-1 1v7a1 1 0 0 0 1 1h7a1 1 0 0 0 1-1V7a1 1 0 0 0-1-1H3Zm1.75 2.5a.75.75 0 0 0 0 1.5h3.5a.75.75 0 0 0 0-1.5h-3.5ZM4 11.75a.75.75 0 0 1 .75-.75h3.5a.75.75 0 0 1 0 1.5h-3.5a.75.75 0 0 1-.75-.75Z"
clip-rule="evenodd"
/>
</svg>
{/if}
</button>
<Tooltip content="Create new key">
<button
class=" px-1.5 py-1 dark:hover:bg-gray-800transition rounded-lg"
on:click={() => {
createAPIKeyHandler();
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="2"
stroke="currentColor"
class="size-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0 3.181 3.183a8.25 8.25 0 0 0 13.803-3.7M4.031 9.865a8.25 8.25 0 0 1 13.803-3.7l3.181 3.182m0-4.991v4.99"
/>
</svg>
</button>
</Tooltip>
{:else}
<button
class="flex gap-1.5 items-center font-medium px-3.5 py-1.5 rounded-lg bg-gray-100/70 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 transition"
on:click={() => {
createAPIKeyHandler();
}}
>
<Plus strokeWidth="2" className=" size-3.5" />
Create new secret key</button
>
{/if}
</div>
</div>
</div>
</div> </div>
<div class="flex justify-end pt-3 text-sm font-medium"> <div class="flex justify-end pt-3 text-sm font-medium">

View file

@ -0,0 +1,15 @@
<script lang="ts">
export let className = 'w-4 h-4';
export let strokeWidth = '1.5';
</script>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width={strokeWidth}
stroke="currentColor"
class={className}
>
<path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" />
</svg>