forked from open-webui/open-webui
Merge pull request #624 from explorigin/session-security
Improve Session Security
This commit is contained in:
commit
323ec3787e
8 changed files with 34 additions and 23 deletions
|
@ -25,7 +25,7 @@ ENV OLLAMA_API_BASE_URL "/ollama/api"
|
||||||
ENV OPENAI_API_BASE_URL ""
|
ENV OPENAI_API_BASE_URL ""
|
||||||
ENV OPENAI_API_KEY ""
|
ENV OPENAI_API_KEY ""
|
||||||
|
|
||||||
ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY"
|
ENV WEBUI_SECRET_KEY ""
|
||||||
|
|
||||||
WORKDIR /app/backend
|
WORKDIR /app/backend
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,7 @@ import uuid
|
||||||
from peewee import *
|
from peewee import *
|
||||||
|
|
||||||
from apps.web.models.users import UserModel, Users
|
from apps.web.models.users import UserModel, Users
|
||||||
from utils.utils import (
|
from utils.utils import verify_password
|
||||||
verify_password,
|
|
||||||
get_password_hash,
|
|
||||||
bearer_scheme,
|
|
||||||
create_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
from apps.web.internal.db import DB
|
from apps.web.internal.db import DB
|
||||||
|
|
||||||
|
|
|
@ -93,7 +93,7 @@ async def update_password(
|
||||||
async def signin(form_data: SigninForm):
|
async def signin(form_data: SigninForm):
|
||||||
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
|
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
|
||||||
if user:
|
if user:
|
||||||
token = create_token(data={"email": user.email})
|
token = create_token(data={"id": user.id})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
|
@ -132,7 +132,7 @@ async def signup(request: Request, form_data: SignupForm):
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
token = create_token(data={"email": user.email})
|
token = create_token(data={"id": user.id})
|
||||||
# response.set_cookie(key='token', value=token, httponly=True)
|
# response.set_cookie(key='token', value=token, httponly=True)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -25,9 +25,6 @@ from apps.web.models.tags import (
|
||||||
Tags,
|
Tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.utils import (
|
|
||||||
bearer_scheme,
|
|
||||||
)
|
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
|
@ -98,12 +98,15 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.61")
|
||||||
WEBUI_AUTH = True
|
WEBUI_AUTH = True
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBUI_JWT_SECRET_KEY
|
# WEBUI_SECRET_KEY
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t")
|
WEBUI_SECRET_KEY = os.environ.get(
|
||||||
|
"WEBUI_SECRET_KEY",
|
||||||
|
os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") # DEPRECATED: remove at next major version
|
||||||
|
)
|
||||||
|
|
||||||
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
|
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
||||||
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
|
|
|
@ -3,5 +3,20 @@
|
||||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||||
cd "$SCRIPT_DIR" || exit
|
cd "$SCRIPT_DIR" || exit
|
||||||
|
|
||||||
|
KEY_FILE=.webui_secret_key
|
||||||
|
|
||||||
PORT="${PORT:-8080}"
|
PORT="${PORT:-8080}"
|
||||||
exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'
|
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
|
||||||
|
echo No WEBUI_SECRET_KEY provided
|
||||||
|
|
||||||
|
if ! [ -e "$KEY_FILE" ]; then
|
||||||
|
echo Generating WEBUI_SECRET_KEY
|
||||||
|
# Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one.
|
||||||
|
echo $(head -c 12 /dev/random | base64) > $KEY_FILE
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo Loading WEBUI_SECRET_KEY from $KEY_FILE
|
||||||
|
WEBUI_SECRET_KEY=`cat $KEY_FILE`
|
||||||
|
fi
|
||||||
|
|
||||||
|
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'
|
||||||
|
|
|
@ -14,14 +14,14 @@ import config
|
||||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY
|
SESSION_SECRET = config.WEBUI_SECRET_KEY
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Auth Utils
|
# Auth Utils
|
||||||
##############
|
##############
|
||||||
|
|
||||||
bearer_scheme = HTTPBearer()
|
bearer_security = HTTPBearer()
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,13 +42,13 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
payload.update({"exp": expire})
|
payload.update({"exp": expire})
|
||||||
|
|
||||||
encoded_jwt = jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
def decode_token(token: str) -> Optional[dict]:
|
def decode_token(token: str) -> Optional[dict]:
|
||||||
try:
|
try:
|
||||||
decoded = jwt.decode(token, JWT_SECRET_KEY, options={"verify_signature": False})
|
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
|
||||||
return decoded
|
return decoded
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
@ -58,10 +58,10 @@ def extract_token_from_auth_header(auth_header: str):
|
||||||
return auth_header[len("Bearer ") :]
|
return auth_header[len("Bearer ") :]
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
|
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_security)):
|
||||||
data = decode_token(auth_token.credentials)
|
data = decode_token(auth_token.credentials)
|
||||||
if data != None and "email" in data:
|
if data != None and "id" in data:
|
||||||
user = Users.get_user_by_email(data["email"])
|
user = Users.get_user_by_id(data["id"])
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
|
@ -26,6 +26,7 @@ services:
|
||||||
- ${OLLAMA_WEBUI_PORT-3000}:8080
|
- ${OLLAMA_WEBUI_PORT-3000}:8080
|
||||||
environment:
|
environment:
|
||||||
- 'OLLAMA_API_BASE_URL=http://ollama:11434/api'
|
- 'OLLAMA_API_BASE_URL=http://ollama:11434/api'
|
||||||
|
- 'WEBUI_SECRET_KEY='
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- host.docker.internal:host-gateway
|
- host.docker.internal:host-gateway
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
Loading…
Reference in a new issue