Merge pull request #624 from explorigin/session-security

Improve Session Security
This commit is contained in:
Timothy Jaeryang Baek 2024-02-03 17:41:31 -08:00 committed by GitHub
commit 323ec3787e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 34 additions and 23 deletions

View file

@ -25,7 +25,7 @@ ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV OPENAI_API_BASE_URL "" ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY "" ENV OPENAI_API_KEY ""
ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY" ENV WEBUI_SECRET_KEY ""
WORKDIR /app/backend WORKDIR /app/backend

View file

@ -5,12 +5,7 @@ import uuid
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.web.models.users import UserModel, Users
from utils.utils import ( from utils.utils import verify_password
verify_password,
get_password_hash,
bearer_scheme,
create_token,
)
from apps.web.internal.db import DB from apps.web.internal.db import DB

View file

@ -93,7 +93,7 @@ async def update_password(
async def signin(form_data: SigninForm): async def signin(form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password) user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
token = create_token(data={"email": user.email}) token = create_token(data={"id": user.id})
return { return {
"token": token, "token": token,
@ -132,7 +132,7 @@ async def signup(request: Request, form_data: SignupForm):
) )
if user: if user:
token = create_token(data={"email": user.email}) token = create_token(data={"id": user.id})
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
return { return {

View file

@ -25,9 +25,6 @@ from apps.web.models.tags import (
Tags, Tags,
) )
from utils.utils import (
bearer_scheme,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()

View file

@ -98,12 +98,15 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.61")
WEBUI_AUTH = True WEBUI_AUTH = True
#################################### ####################################
# WEBUI_JWT_SECRET_KEY # WEBUI_SECRET_KEY
#################################### ####################################
WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") # DEPRECATED: remove at next major version
)
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
#################################### ####################################

View file

@ -3,5 +3,20 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd "$SCRIPT_DIR" || exit cd "$SCRIPT_DIR" || exit
KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}" PORT="${PORT:-8080}"
exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*' if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
echo No WEBUI_SECRET_KEY provided
if ! [ -e "$KEY_FILE" ]; then
echo Generating WEBUI_SECRET_KEY
# Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one.
echo $(head -c 12 /dev/random | base64) > $KEY_FILE
fi
echo Loading WEBUI_SECRET_KEY from $KEY_FILE
WEBUI_SECRET_KEY=`cat $KEY_FILE`
fi
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'

View file

@ -14,14 +14,14 @@ import config
logging.getLogger("passlib").setLevel(logging.ERROR) logging.getLogger("passlib").setLevel(logging.ERROR)
JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY SESSION_SECRET = config.WEBUI_SECRET_KEY
ALGORITHM = "HS256" ALGORITHM = "HS256"
############## ##############
# Auth Utils # Auth Utils
############## ##############
bearer_scheme = HTTPBearer() bearer_security = HTTPBearer()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@ -42,13 +42,13 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
payload.update({"exp": expire}) payload.update({"exp": expire})
encoded_jwt = jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
def decode_token(token: str) -> Optional[dict]: def decode_token(token: str) -> Optional[dict]:
try: try:
decoded = jwt.decode(token, JWT_SECRET_KEY, options={"verify_signature": False}) decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
return decoded return decoded
except Exception as e: except Exception as e:
return None return None
@ -58,10 +58,10 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_security)):
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "email" in data: if data != None and "id" in data:
user = Users.get_user_by_email(data["email"]) user = Users.get_user_by_id(data["id"])
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View file

@ -26,6 +26,7 @@ services:
- ${OLLAMA_WEBUI_PORT-3000}:8080 - ${OLLAMA_WEBUI_PORT-3000}:8080
environment: environment:
- 'OLLAMA_API_BASE_URL=http://ollama:11434/api' - 'OLLAMA_API_BASE_URL=http://ollama:11434/api'
- 'WEBUI_SECRET_KEY='
extra_hosts: extra_hosts:
- host.docker.internal:host-gateway - host.docker.internal:host-gateway
restart: unless-stopped restart: unless-stopped