Compare commits

..

3 commits

Author SHA1 Message Date
Timothy Jaeryang Baek
59de980306
Update main.py 2024-02-23 04:51:41 -05:00
Timothy Jaeryang Baek
837feb4e79
Merge branch 'main' into functions 2024-02-23 04:51:01 -05:00
Timothy J. Baek
6caa7750bb feat: function plugins support 2024-02-19 01:58:14 -08:00
253 changed files with 8945 additions and 36744 deletions

View file

@ -7,6 +7,7 @@ node_modules
/package /package
.env .env
.env.* .env.*
!.env.example
vite.config.js.timestamp-* vite.config.js.timestamp-*
vite.config.ts.timestamp-* vite.config.ts.timestamp-*
__pycache__ __pycache__

View file

@ -1,6 +1,6 @@
# Ollama URL for the backend to connect # Ollama URL for the backend to connect
# The path '/ollama' will be redirected to the specified backend URL # The path '/ollama/api' will be redirected to the specified backend URL
OLLAMA_BASE_URL='http://localhost:11434' OLLAMA_API_BASE_URL='http://localhost:11434/api'
OPENAI_API_BASE_URL='' OPENAI_API_BASE_URL=''
OPENAI_API_KEY='' OPENAI_API_KEY=''
@ -9,9 +9,4 @@ OPENAI_API_KEY=''
# DO NOT TRACK # DO NOT TRACK
SCARF_NO_ANALYTICS=true SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true DO_NOT_TRACK=true
ANONYMIZED_TELEMETRY=false
# Use locally bundled version of the LiteLLM cost map json
# to avoid repetitive startup connections
LITELLM_LOCAL_MODEL_COST_MAP="True"

View file

@ -4,7 +4,6 @@ module.exports = {
'eslint:recommended', 'eslint:recommended',
'plugin:@typescript-eslint/recommended', 'plugin:@typescript-eslint/recommended',
'plugin:svelte/recommended', 'plugin:svelte/recommended',
'plugin:cypress/recommended',
'prettier' 'prettier'
], ],
parser: '@typescript-eslint/parser', parser: '@typescript-eslint/parser',

View file

@ -24,9 +24,6 @@ assignees: ''
## Environment ## Environment
- **Open WebUI Version:** [e.g., 0.1.120]
- **Ollama (if applicable):** [e.g., 0.1.30, 0.1.32-rc1]
- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04] - **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0] - **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
@ -35,7 +32,7 @@ assignees: ''
**Confirmation:** **Confirmation:**
- [ ] I have read and followed all the instructions provided in the README.md. - [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I am on the latest version of both Open WebUI and Ollama. - [ ] I have reviewed the troubleshooting.md document.
- [ ] I have included the browser console logs. - [ ] I have included the browser console logs.
- [ ] I have included the Docker container logs. - [ ] I have included the Docker container logs.

View file

@ -1,11 +0,0 @@
version: 2
updates:
- package-ecosystem: pip
directory: "/backend"
schedule:
interval: daily
time: "13:00"
groups:
python-packages:
patterns:
- "*"

View file

@ -1,50 +0,0 @@
## Pull Request Checklist
- [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Have you written and run sufficient tests for the changes?
- [ ] **Code Review:** Have you self-reviewed your code and addressed any coding standard issues?
---
## Description
[Insert a brief description of the changes made in this pull request, including any relevant motivation and impact.]
---
### Changelog Entry
### Added
- [List any new features, functionalities, or additions]
### Fixed
- [List any fixes, corrections, or bug fixes]
### Changed
- [List any changes, updates, refactorings, or optimizations]
### Removed
- [List any removed features, files, or deprecated functionalities]
### Security
- [List any new or updated security-related changes, including vulnerability fixes]
### Breaking Changes
- [List any breaking changes affecting compatibility or functionality]
---
### Additional Information
- [Insert any additional context, notes, or explanations for the changes]
- [Reference any related issues, commits, or other relevant information]

View file

@ -19,34 +19,24 @@ jobs:
echo "No changes to package.json" echo "No changes to package.json"
exit 1 exit 1
} }
- name: Get version number from package.json - name: Get version number from package.json
id: get_version id: get_version
run: | run: |
VERSION=$(jq -r '.version' package.json) VERSION=$(jq -r '.version' package.json)
echo "::set-output name=version::$VERSION" echo "::set-output name=version::$VERSION"
- name: Extract latest CHANGELOG entry
id: changelog
run: |
CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
echo "Extracted latest release notes from CHANGELOG.md:"
echo -e "$CHANGELOG_CONTENT"
echo "::set-output name=content::$CHANGELOG_ESCAPED"
- name: Create GitHub release - name: Create GitHub release
uses: actions/github-script@v5 uses: actions/github-script@v5
with: with:
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
script: | script: |
const changelog = `${{ steps.changelog.outputs.content }}`;
const release = await github.rest.repos.createRelease({ const release = await github.rest.repos.createRelease({
owner: context.repo.owner, owner: context.repo.owner,
repo: context.repo.repo, repo: context.repo.repo,
tag_name: `v${{ steps.get_version.outputs.version }}`, tag_name: `v${{ steps.get_version.outputs.version }}`,
name: `v${{ steps.get_version.outputs.version }}`, name: `v${{ steps.get_version.outputs.version }}`,
body: changelog, body: 'Automatically created new release',
}) })
console.log(`Created release ${release.data.html_url}`) console.log(`Created release ${release.data.html_url}`)
@ -57,14 +47,3 @@ jobs:
path: . path: .
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Trigger Docker build workflow
uses: actions/github-script@v7
with:
script: |
github.rest.actions.createWorkflowDispatch({
owner: context.repo.owner,
repo: context.repo.repo,
workflow_id: 'docker-build.yaml',
ref: 'v${{ steps.get_version.outputs.version }}',
})

View file

@ -1,7 +1,8 @@
name: Create and publish Docker images with specific build args #
name: Create and publish a Docker image
# Configures this workflow to run every time a change is pushed to the branch called `release`.
on: on:
workflow_dispatch:
push: push:
branches: branches:
- main - main
@ -9,380 +10,56 @@ on:
tags: tags:
- v* - v*
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env: env:
REGISTRY: git.depeuter.dev REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }} IMAGE_NAME: ${{ github.repository }}
RUNNER_TOOL_CACHE: /toolcache
FULL_IMAGE_NAME: ${{ env.REGISTRY }}/${{ github.repository }}
# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs: jobs:
build-main-image: build-and-push-image:
runs-on: ubuntu-latest runs-on: ubuntu-latest
container:
image: catthehacker/ubuntu:act-latest
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
permissions: permissions:
contents: read contents: read
packages: write packages: write
strategy: #
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps: steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
# Required for multi architecture build
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v3 uses: docker/setup-qemu-action@v3
# Required for multi architecture build
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
registry: ${{ env.REGISTRY }} registry: ${{ env.REGISTRY }}
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.CI_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (default latest tag) - name: Extract metadata for Docker images
id: meta id: meta
uses: docker/metadata-action@v5 uses: docker/metadata-action@v5
with: with:
images: ${{ env.FULL_IMAGE_NAME }} images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
# This configuration dynamically generates tags based on the branch, tag, commit, and custom suffix for lite version.
tags: | tags: |
type=ref,event=branch type=ref,event=branch
type=ref,event=tag type=ref,event=tag
type=sha,prefix=git- type=sha,prefix=git-
type=semver,pattern={{version}} type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
flavor: | flavor: |
latest=${{ github.ref == 'refs/heads/main' }} latest=${{ github.ref == 'refs/heads/main' }}
- name: Build Docker image (latest) - name: Build and push Docker image
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
id: build
with: with:
context: . context: .
push: true push: true
platforms: ${{ matrix.platform }} platforms: linux/amd64,linux/arm64
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-main-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
build-cuda-image:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.CI_TOKEN }}
- name: Extract metadata for Docker images (default latest tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-cuda,onlatest=true
- name: Build Docker image (cuda)
uses: docker/build-push-action@v5
id: build
with:
context: .
push: true
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: USE_CUDA=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-cuda-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
build-ollama-image:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.CI_TOKEN }}
- name: Extract metadata for Docker images (ollama tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-ollama,onlatest=true
- name: Build Docker image (ollama)
uses: docker/build-push-action@v5
id: build
with:
context: .
push: true
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: USE_OLLAMA=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-ollama-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge-main-images:
runs-on: ubuntu-latest
needs: [ build-main-image ]
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
pattern: digests-main-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.CI_TOKEN }}
- name: Extract metadata for Docker images (default latest tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-cuda-images:
runs-on: ubuntu-latest
needs: [ build-cuda-image ]
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
pattern: digests-cuda-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.CI_TOKEN }}
- name: Extract metadata for Docker images (default latest tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-cuda,onlatest=true
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-ollama-images:
runs-on: ubuntu-latest
needs: [ build-ollama-image ]
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
pattern: digests-ollama-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.CI_TOKEN }}
- name: Extract metadata for Docker images (default ollama tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-ollama,onlatest=true
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}

View file

@ -1,39 +1,27 @@
name: Python CI name: Python CI
on: on:
push: push:
branches: branches: ['main']
- main
- dev
pull_request: pull_request:
branches:
- main
- dev
jobs: jobs:
build: build:
name: 'Format Backend' name: 'Format Backend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: [3.11] node-version:
- latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Use Python
- name: Set up Python uses: actions/setup-python@v4
uses: actions/setup-python@v2 - name: Use Bun
with: uses: oven-sh/setup-bun@v1
python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install black pip install yapf
- name: Format backend - name: Format backend
run: npm run format:backend run: bun run format:backend
- name: Check for changes after format
run: git diff --exit-code

View file

@ -1,39 +1,22 @@
name: Frontend Build name: Bun CI
on: on:
push: push:
branches: branches: ['main']
- main
- dev
pull_request: pull_request:
branches:
- main
- dev
jobs: jobs:
build: build:
name: 'Format & Build Frontend' name: 'Format & Build Frontend'
env:
PUBLIC_API_BASE_URL: ''
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout Repository - uses: actions/checkout@v4
uses: actions/checkout@v4 - name: Use Bun
uses: oven-sh/setup-bun@v1
- name: Setup Node.js - run: bun --version
uses: actions/setup-node@v3 - name: Install frontend dependencies
with: run: bun install
node-version: '20' # Or specify any other version you want to use - name: Format frontend
run: bun run format
- name: Install Dependencies - name: Build frontend
run: npm install run: bun run build
- name: Format Frontend
run: npm run format
- name: Run i18next
run: npm run i18n:parse
- name: Check for Changes After Format
run: git diff --exit-code
- name: Build Frontend
run: npm run build

View file

@ -1,186 +0,0 @@
name: Integration Test
on:
push:
branches:
- main
- dev
pull_request:
branches:
- main
- dev
jobs:
cypress-run:
name: Run Cypress Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
- name: Build and run Compose Stack
run: |
docker compose up --detach --build
- name: Preload Ollama model
run: |
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
- name: Cypress run
uses: cypress-io/github-action@v6
with:
browser: chrome
wait-on: 'http://localhost:3000'
config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4
if: always()
name: Upload Cypress videos
with:
name: cypress-videos
path: cypress/videos
if-no-files-found: ignore
- name: Extract Compose logs
if: always()
run: |
docker compose logs > compose-logs.txt
- uses: actions/upload-artifact@v4
if: always()
name: Upload Compose logs
with:
name: compose-logs
path: compose-logs.txt
if-no-files-found: ignore
migration_test:
name: Run Migration Tests
runs-on: ubuntu-latest
services:
postgres:
image: postgres
env:
POSTGRES_PASSWORD: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql
# options: >-
# --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# - 3306:3306
steps:
- name: Checkout Repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up uv
uses: yezz123/setup-uv@v4
with:
uv-venv: venv
- name: Activate virtualenv
run: |
. venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: Install dependencies
run: |
uv pip install -r backend/requirements.txt
- name: Test backend with SQLite
id: sqlite
env:
WEBUI_SECRET_KEY: secret-key
GLOBAL_LOG_LEVEL: debug
run: |
cd backend
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to 20 seconds for the server to start
for i in {1..20}; do
curl -s http://localhost:8080/api/config > /dev/null && break
sleep 1
if [ $i -eq 20 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
fi
done
# Check that the server is still running after 5 seconds
sleep 5
if ! kill -0 $UVICORN_PID; then
echo "Server has stopped"
exit 1
fi
- name: Test backend with Postgres
if: success() || steps.sqlite.conclusion == 'failure'
env:
WEBUI_SECRET_KEY: secret-key
GLOBAL_LOG_LEVEL: debug
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
run: |
cd backend
uvicorn main:app --port "8081" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to 20 seconds for the server to start
for i in {1..20}; do
curl -s http://localhost:8081/api/config > /dev/null && break
sleep 1
if [ $i -eq 20 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
fi
done
# Check that the server is still running after 5 seconds
sleep 5
if ! kill -0 $UVICORN_PID; then
echo "Server has stopped"
exit 1
fi
# - name: Test backend with MySQL
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
# env:
# WEBUI_SECRET_KEY: secret-key
# GLOBAL_LOG_LEVEL: debug
# DATABASE_URL: mysql://root:mysql@localhost:3306/mysql
# run: |
# cd backend
# uvicorn main:app --port "8083" --forwarded-allow-ips '*' &
# UVICORN_PID=$!
# # Wait up to 20 seconds for the server to start
# for i in {1..20}; do
# curl -s http://localhost:8083/api/config > /dev/null && break
# sleep 1
# if [ $i -eq 20 ]; then
# echo "Server failed to start"
# kill -9 $UVICORN_PID
# exit 1
# fi
# done
# # Check that the server is still running after 5 seconds
# sleep 5
# if ! kill -0 $UVICORN_PID; then
# echo "Server has stopped"
# exit 1
# fi

8
.gitignore vendored
View file

@ -166,7 +166,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/ #.idea/
# Logs # Logs
logs logs
@ -297,8 +297,4 @@ dist
.yarn/unplugged .yarn/unplugged
.yarn/build-state.yml .yarn/build-state.yml
.yarn/install-state.gz .yarn/install-state.gz
.pnp.* .pnp.*
# cypress artifacts
cypress/videos
cypress/screenshots

View file

@ -5,364 +5,6 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.123] - 2024-05-02
### Added
- **🎨 New Landing Page Design**: Refreshed design for a more modern look and optimized use of screen space.
- **📹 Youtube RAG Pipeline**: Introduces dedicated RAG pipeline for Youtube videos, enabling interaction with video transcriptions directly.
- **🔧 Enhanced Admin Panel**: Streamlined user management with options to add users directly or in bulk via CSV import.
- **👥 '@' Model Integration**: Easily switch to specific models during conversations; old collaborative chat feature phased out.
- **🌐 Language Enhancements**: Swedish translation added, plus improvements to German, Spanish, and the addition of Doge translation.
### Fixed
- **🗑️ Delete Chat Shortcut**: Addressed issue where shortcut wasn't functioning.
- **🖼️ Modal Closing Bug**: Resolved unexpected closure of modal when dragging from within.
- **✏️ Edit Button Styling**: Fixed styling inconsistency with edit buttons.
- **🌐 Image Generation Compatibility Issue**: Rectified image generation compatibility issue with third-party APIs.
- **📱 iOS PWA Icon Fix**: Corrected iOS PWA home screen icon shape.
- **🔍 Scroll Gesture Bug**: Adjusted gesture sensitivity to prevent accidental activation when scrolling through code on mobile; now requires scrolling from the leftmost side to open the sidebar.
### Changed
- **🔄 Unlimited Context Length**: Advanced settings now allow unlimited max context length (previously limited to 16000).
- **👑 Super Admin Assignment**: The first signup is automatically assigned a super admin role, unchangeable by other admins.
- **🛡️ Admin User Restrictions**: User action buttons from the admin panel are now disabled for users with admin roles.
- **🔝 Default Model Selector**: Set as default model option now exclusively available on the landing page.
## [0.1.122] - 2024-04-27
### Added
- **🌟 Enhanced RAG Pipeline**: Now with hybrid searching via 'BM25', reranking powered by 'CrossEncoder', and configurable relevance score thresholds.
- **🛢️ External Database Support**: Seamlessly connect to custom SQLite or Postgres databases using the 'DATABASE_URL' environment variable.
- **🌐 Remote ChromaDB Support**: Introducing the capability to connect to remote ChromaDB servers.
- **👨‍💼 Improved Admin Panel**: Admins can now conveniently check users' chat lists and last active status directly from the admin panel.
- **🎨 Splash Screen**: Introducing a loading splash screen for a smoother user experience.
- **🌍 Language Support Expansion**: Added support for Bangla (bn-BD), along with enhancements to Chinese, Spanish, and Ukrainian translations.
- **💻 Improved LaTeX Rendering Performance**: Enjoy faster rendering times for LaTeX equations.
- **🔧 More Environment Variables**: Explore additional environment variables in our documentation (https://docs.openwebui.com), including the 'ENABLE_LITELLM' option to manage memory usage.
### Fixed
- **🔧 Ollama Compatibility**: Resolved errors occurring when Ollama server version isn't an integer, such as SHA builds or RCs.
- **🐛 Various OpenAI API Issues**: Addressed several issues related to the OpenAI API.
- **🛑 Stop Sequence Issue**: Fixed the problem where the stop sequence with a backslash '\' was not functioning.
- **🔤 Font Fallback**: Corrected font fallback issue.
### Changed
- **⌨️ Prompt Input Behavior on Mobile**: Enter key prompt submission disabled on mobile devices for improved user experience.
## [0.1.121] - 2024-04-24
### Fixed
- **🔧 Translation Issues**: Addressed various translation discrepancies.
- **🔒 LiteLLM Security Fix**: Updated LiteLLM version to resolve a security vulnerability.
- **🖥️ HTML Tag Display**: Rectified the issue where the '< br >' tag wasn't displaying correctly.
- **🔗 WebSocket Connection**: Resolved the failure of WebSocket connection under HTTPS security for ComfyUI server.
- **📜 FileReader Optimization**: Implemented FileReader initialization per image in multi-file drag & drop to ensure reusability.
- **🏷️ Tag Display**: Corrected tag display inconsistencies.
- **📦 Archived Chat Styling**: Fixed styling issues in archived chat.
- **🔖 Safari Copy Button Bug**: Addressed the bug where the copy button failed to copy links in Safari.
## [0.1.120] - 2024-04-20
### Added
- **📦 Archive Chat Feature**: Easily archive chats with a new sidebar button, and access archived chats via the profile button > archived chats.
- **🔊 Configurable Text-to-Speech Endpoint**: Customize your Text-to-Speech experience with configurable OpenAI endpoints.
- **🛠️ Improved Error Handling**: Enhanced error message handling for connection failures.
- **⌨️ Enhanced Shortcut**: When editing messages, use ctrl/cmd+enter to save and submit, and esc to close.
- **🌐 Language Support**: Added support for Georgian and enhanced translations for Portuguese and Vietnamese.
### Fixed
- **🔧 Model Selector**: Resolved issue where default model selection was not saving.
- **🔗 Share Link Copy Button**: Fixed bug where the copy button wasn't copying links in Safari.
- **🎨 Light Theme Styling**: Addressed styling issue with the light theme.
## [0.1.119] - 2024-04-16
### Added
- **🌟 Enhanced RAG Embedding Support**: Ollama, and OpenAI models can now be used for RAG embedding model.
- **🔄 Seamless Integration**: Copy 'ollama run <model name>' directly from Ollama page to easily select and pull models.
- **🏷️ Tagging Feature**: Add tags to chats directly via the sidebar chat menu.
- **📱 Mobile Accessibility**: Swipe left and right on mobile to effortlessly open and close the sidebar.
- **🔍 Improved Navigation**: Admin panel now supports pagination for user list.
- **🌍 Additional Language Support**: Added Polish language support.
### Fixed
- **🌍 Language Enhancements**: Vietnamese and Spanish translations have been improved.
- **🔧 Helm Fixes**: Resolved issues with Helm trailing slash and manifest.json.
### Changed
- **🐳 Docker Optimization**: Updated docker image build process to utilize 'uv' for significantly faster builds compared to 'pip3'.
## [0.1.118] - 2024-04-10
### Added
- **🦙 Ollama and CUDA Images**: Added support for ':ollama' and ':cuda' tagged images.
- **👍 Enhanced Response Rating**: Now you can annotate your ratings for better feedback.
- **👤 User Initials Profile Photo**: User initials are now the default profile photo.
- **🔍 Update RAG Embedding Model**: Customize RAG embedding model directly in document settings.
- **🌍 Additional Language Support**: Added Turkish language support.
### Fixed
- **🔒 Share Chat Permission**: Resolved issue with chat sharing permissions.
- **🛠 Modal Close**: Modals can now be closed using the Esc key.
### Changed
- **🎨 Admin Panel Styling**: Refreshed styling for the admin panel.
- **🐳 Docker Image Build**: Updated docker image build process for improved efficiency.
## [0.1.117] - 2024-04-03
### Added
- 🗨️ **Local Chat Sharing**: Share chat links seamlessly between users.
- 🔑 **API Key Generation Support**: Generate secret keys to leverage Open WebUI with OpenAI libraries.
- 📄 **Chat Download as PDF**: Easily download chats in PDF format.
- 📝 **Improved Logging**: Enhancements to logging functionality.
- 📧 **Trusted Email Authentication**: Authenticate using a trusted email header.
### Fixed
- 🌷 **Enhanced Dutch Translation**: Improved translation for Dutch users.
- ⚪ **White Theme Styling**: Resolved styling issue with the white theme.
- 📜 **LaTeX Chat Screen Overflow**: Fixed screen overflow issue with LaTeX rendering.
- 🔒 **Security Patches**: Applied necessary security patches.
## [0.1.116] - 2024-03-31
### Added
- **🔄 Enhanced UI**: Model selector now conveniently located in the navbar, enabling seamless switching between multiple models during conversations.
- **🔍 Improved Model Selector**: Directly pull a model from the selector/Models now display detailed information for better understanding.
- **💬 Webhook Support**: Now compatible with Google Chat and Microsoft Teams.
- **🌐 Localization**: Korean translation (I18n) now available.
- **🌑 Dark Theme**: OLED dark theme introduced for reduced strain during prolonged usage.
- **🏷️ Tag Autocomplete**: Dropdown feature added for effortless chat tagging.
### Fixed
- **🔽 Auto-Scrolling**: Addressed OpenAI auto-scrolling issue.
- **🏷️ Tag Validation**: Implemented tag validation to prevent empty string tags.
- **🚫 Model Whitelisting**: Resolved LiteLLM model whitelisting issue.
- **✅ Spelling**: Corrected various spelling issues for improved readability.
## [0.1.115] - 2024-03-24
### Added
- **🔍 Custom Model Selector**: Easily find and select custom models with the new search filter feature.
- **🛑 Cancel Model Download**: Added the ability to cancel model downloads.
- **🎨 Image Generation ComfyUI**: Image generation now supports ComfyUI.
- **🌟 Updated Light Theme**: Updated the light theme for a fresh look.
- **🌍 Additional Language Support**: Now supporting Bulgarian, Italian, Portuguese, Japanese, and Dutch.
### Fixed
- **🔧 Fixed Broken Experimental GGUF Upload**: Resolved issues with experimental GGUF upload functionality.
### Changed
- **🔄 Vector Storage Reset Button**: Moved the reset vector storage button to document settings.
## [0.1.114] - 2024-03-20
### Added
- **🔗 Webhook Integration**: Now you can subscribe to new user sign-up events via webhook. Simply navigate to the admin panel > admin settings > webhook URL.
- **🛡️ Enhanced Model Filtering**: Alongside Ollama, OpenAI proxy model whitelisting, we've added model filtering functionality for LiteLLM proxy.
- **🌍 Expanded Language Support**: Spanish, Catalan, and Vietnamese languages are now available, with improvements made to others.
### Fixed
- **🔧 Input Field Spelling**: Resolved issue with spelling mistakes in input fields.
- **🖊️ Light Mode Styling**: Fixed styling issue with light mode in document adding.
### Changed
- **🔄 Language Sorting**: Languages are now sorted alphabetically by their code for improved organization.
## [0.1.113] - 2024-03-18
### Added
- 🌍 **Localization**: You can now change the UI language in Settings > General. We support Ukrainian, German, Farsi (Persian), Traditional and Simplified Chinese and French translations. You can help us to translate the UI into your language! More info in our [CONTRIBUTION.md](https://github.com/open-webui/open-webui/blob/main/docs/CONTRIBUTING.md#-translations-and-internationalization).
- 🎨 **System-wide Theme**: Introducing a new system-wide theme for enhanced visual experience.
### Fixed
- 🌑 **Dark Background on Select Fields**: Improved readability by adding a dark background to select fields, addressing issues on certain browsers/devices.
- **Multiple OPENAI_API_BASE_URLS Issue**: Resolved issue where multiple base URLs caused conflicts when one wasn't functioning.
- **RAG Encoding Issue**: Fixed encoding problem in RAG.
- **npm Audit Fix**: Addressed npm audit findings.
- **Reduced Scroll Threshold**: Improved auto-scroll experience by reducing the scroll threshold from 50px to 5px.
### Changed
- 🔄 **Sidebar UI Update**: Updated sidebar UI to feature a chat menu dropdown, replacing two icons for improved navigation.
## [0.1.112] - 2024-03-15
### Fixed
- 🗨️ Resolved chat malfunction after image generation.
- 🎨 Fixed various RAG issues.
- 🧪 Rectified experimental broken GGUF upload logic.
## [0.1.111] - 2024-03-10
### Added
- 🛡️ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role.
- 🔄 **Update All Models**: Added a convenient button to update all models at once.
- 📄 **Toggle PDF OCR**: Users can now toggle PDF OCR option for improved parsing performance.
- 🎨 **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111.
- 🛠️ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow.
### Fixed
- 🔒 **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094).
- 🔧 **Misalignment Issue**: Corrected misalignment of Edit and Delete Icons when Chat Title is Empty (Issue #1104).
- 🔄 **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105).
- 📁 **File Handling Bug**: Addressed File Not Found Notification when Dropping a Conversation Element (Issue #1098).
- 🖱️ **Dragged File Styling**: Fixed dragged file layover styling issue.
## [0.1.110] - 2024-03-06
### Added
- **🌐 Multiple OpenAI Servers Support**: Enjoy seamless integration with multiple OpenAI-compatible APIs, now supported natively.
### Fixed
- **🔍 OCR Issue**: Resolved PDF parsing issue caused by OCR malfunction.
- **🚫 RAG Issue**: Fixed the RAG functionality, ensuring it operates smoothly.
- **📄 "Add Docs" Model Button**: Addressed the non-functional behavior of the "Add Docs" model button.
## [0.1.109] - 2024-03-06
### Added
- **🔄 Multiple Ollama Servers Support**: Enjoy enhanced scalability and performance with support for multiple Ollama servers in a single WebUI. Load balancing features are now available, providing improved efficiency (#788, #278).
- **🔧 Support for Claude 3 and Gemini**: Responding to user requests, we've expanded our toolset to include Claude 3 and Gemini, offering a wider range of functionalities within our platform (#1064).
- **🔍 OCR Functionality for PDF Loader**: We've augmented our PDF loader with Optical Character Recognition (OCR) capabilities. Now, extract text from scanned documents and images within PDFs, broadening the scope of content processing (#1050).
### Fixed
- **🛠️ RAG Collection**: Implemented a dynamic mechanism to recreate RAG collections, ensuring users have up-to-date and accurate data (#1031).
- **📝 User Agent Headers**: Fixed issue of RAG web requests being sent with empty user_agent headers, reducing rejections from certain websites. Realistic headers are now utilized for these requests (#1024).
- **⏹️ Playground Cancel Functionality**: Introducing a new "Cancel" option for stopping Ollama generation in the Playground, enhancing user control and usability (#1006).
- **🔤 Typographical Error in 'ASSISTANT' Field**: Corrected a typographical error in the 'ASSISTANT' field within the GGUF model upload template for accuracy and consistency (#1061).
### Changed
- **🔄 Refactored Message Deletion Logic**: Streamlined message deletion process for improved efficiency and user experience, simplifying interactions within the platform (#1004).
- **⚠️ Deprecation of `OLLAMA_API_BASE_URL`**: Deprecated `OLLAMA_API_BASE_URL` environment variable; recommend using `OLLAMA_BASE_URL` instead. Refer to our documentation for further details.
## [0.1.108] - 2024-03-02
### Added
- **🎮 Playground Feature (Beta)**: Explore the full potential of the raw API through an intuitive UI with our new playground feature, accessible to admins. Simply click on the bottom name area of the sidebar to access it. The playground feature offers two modes text completion (notebook) and chat completion. As it's in beta, please report any issues you encounter.
- **🛠️ Direct Database Download for Admins**: Admins can now download the database directly from the WebUI via the admin settings.
- **🎨 Additional RAG Settings**: Customize your RAG process with the ability to edit the TOP K value. Navigate to Documents > Settings > General to make changes.
- **🖥️ UI Improvements**: Tooltips now available in the input area and sidebar handle. More tooltips will be added across other parts of the UI.
### Fixed
- Resolved input autofocus issue on mobile when the sidebar is open, making it easier to use.
- Corrected numbered list display issue in Safari (#963).
- Restricted user ability to delete chats without proper permissions (#993).
### Changed
- **Simplified Ollama Settings**: Ollama settings now don't require the `/api` suffix. You can now utilize the Ollama base URL directly, e.g., `http://localhost:11434`. Also, an `OLLAMA_BASE_URL` environment variable has been added.
- **Database Renaming**: Starting from this release, `ollama.db` will be automatically renamed to `webui.db`.
## [0.1.107] - 2024-03-01
### Added
- **🚀 Makefile and LLM Update Script**: Included Makefile and a script for LLM updates in the repository.
### Fixed
- Corrected issue where links in the settings modal didn't appear clickable (#960).
- Fixed problem with web UI port not taking effect due to incorrect environment variable name in run-compose.sh (#996).
- Enhanced user experience by displaying chat in browser title and enabling automatic scrolling to the bottom (#992).
### Changed
- Upgraded toast library from `svelte-french-toast` to `svelte-sonner` for a more polished UI.
- Enhanced accessibility with the addition of dark mode on the authentication page.
## [0.1.106] - 2024-02-27
### Added
- **🎯 Auto-focus Feature**: The input area now automatically focuses when initiating or opening a chat conversation.
### Fixed
- Corrected typo from "HuggingFace" to "Hugging Face" (Issue #924).
- Resolved bug causing errors in chat completion API calls to OpenAI due to missing "num_ctx" parameter (Issue #927).
- Fixed issues preventing text editing, selection, and cursor retention in the input field (Issue #940).
- Fixed a bug where defining an OpenAI-compatible API server using 'OPENAI_API_BASE_URL' containing 'openai' string resulted in hiding models not containing 'gpt' string from the model menu. (Issue #930)
## [0.1.105] - 2024-02-25
### Added
- **📄 Document Selection**: Now you can select and delete multiple documents at once for easier management.
### Changed
- **🏷️ Document Pre-tagging**: Simply click the "+" button at the top, enter tag names in the popup window, or select from a list of existing tags. Then, upload files with the added tags for streamlined organization.
## [0.1.104] - 2024-02-25
### Added
- **🔄 Check for Updates**: Keep your system current by checking for updates conveniently located in Settings > About.
- **🗑️ Automatic Tag Deletion**: Unused tags on the sidebar will now be deleted automatically with just a click.
### Changed
- **🎨 Modernized Styling**: Enjoy a refreshed look with updated styling for a more contemporary experience.
## [0.1.103] - 2024-02-25
### Added
- **🔗 Built-in LiteLLM Proxy**: Now includes LiteLLM proxy within Open WebUI for enhanced functionality.
- Easily integrate existing LiteLLM configurations using `-v /path/to/config.yaml:/app/backend/data/litellm/config.yaml` flag.
- When utilizing Docker container to run Open WebUI, ensure connections to localhost use `host.docker.internal`.
- **🖼️ Image Generation Enhancements**: Introducing Advanced Settings with Image Preview Feature.
- Customize image generation by setting the number of steps; defaults to A1111 value.
### Fixed
- Resolved issue with RAG scan halting document loading upon encountering unsupported MIME types or exceptions (Issue #866).
### Changed
- Ollama is no longer required to run Open WebUI.
- Access our comprehensive documentation at [Open WebUI Documentation](https://docs.openwebui.com/).
## [0.1.102] - 2024-02-22 ## [0.1.102] - 2024-02-22
### Added ### Added

View file

@ -1,128 +1,75 @@
# syntax=docker/dockerfile:1 # syntax=docker/dockerfile:1
# Initialize device type args
# use build args in the docker build commmand with --build-arg="BUILDARG=true"
ARG USE_CUDA=false
ARG USE_OLLAMA=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL=""
######## WebUI frontend ######## FROM node:alpine as build
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
WORKDIR /app WORKDIR /app
# wget embedding model weight from alpine (does not exist from slim-buster)
RUN wget "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" -O - | \
tar -xzf - -C /app
COPY package.json package-lock.json ./ COPY package.json package-lock.json ./
RUN npm ci RUN npm ci
COPY . . COPY . .
RUN npm run build RUN npm run build
######## WebUI backend ########
FROM python:3.11-slim-bookworm as base FROM python:3.11-slim-bookworm as base
# Use args ENV ENV=prod
ARG USE_CUDA ENV PORT ""
ARG USE_OLLAMA
ARG USE_CUDA_VER
ARG USE_EMBEDDING_MODEL
ARG USE_RERANKING_MODEL
## Basis ## ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV ENV=prod \
PORT=8080 \
# pass build args to the build
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
USE_CUDA_DOCKER=${USE_CUDA} \
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
## Basis URL Config ## ENV OPENAI_API_BASE_URL ""
ENV OLLAMA_BASE_URL="/ollama" \ ENV OPENAI_API_KEY ""
OPENAI_API_BASE_URL=""
## API Key and Security Config ## ENV WEBUI_SECRET_KEY ""
ENV OPENAI_API_KEY="" \
WEBUI_SECRET_KEY="" \
SCARF_NO_ANALYTICS=true \
DO_NOT_TRACK=true \
ANONYMIZED_TELEMETRY=false
# Use locally bundled version of the LiteLLM cost map json ENV SCARF_NO_ANALYTICS true
# to avoid repetitive startup connections ENV DO_NOT_TRACK true
ENV LITELLM_LOCAL_MODEL_COST_MAP="True"
######## Preloaded models ########
# whisper TTS Settings
ENV WHISPER_MODEL="base"
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
#### Other models ######################################################### # RAG Embedding Model Settings
## whisper TTS model settings ## # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
ENV WHISPER_MODEL="base" \ # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
## RAG Embedding model settings ## ######## Preloaded models ########
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
## Hugging Face download cache ##
ENV HF_HOME="/app/backend/data/cache/embedding/models"
#### Other models ##########################################################
WORKDIR /app/backend WORKDIR /app/backend
ENV HOME /root
RUN mkdir -p $HOME/.cache/chroma
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# install helper tools
apt-get install -y --no-install-recommends curl && \
# install ollama
curl -fsSL https://ollama.com/install.sh | sh && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
else \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
fi
# install python dependencies # install python dependencies
COPY ./backend/requirements.txt ./requirements.txt COPY ./backend/requirements.txt ./requirements.txt
RUN pip3 install uv && \ RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir
if [ "$USE_CUDA" = "true" ]; then \ RUN pip3 install -r requirements.txt --no-cache-dir
# If you use CUDA the whisper and embedding model will be downloaded on first use
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
fi
# Install pandoc and netcat
# RUN python -c "import pypandoc; pypandoc.download_pandoc()"
RUN apt-get update \
&& apt-get install -y pandoc netcat-openbsd \
&& rm -rf /var/lib/apt/lists/*
# preload embedding model
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
# preload tts model
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
# copy embedding weight from build # copy embedding weight from build
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
# copy built frontend files # copy built frontend files
COPY --from=build /app/build /app/build COPY --from=build /app/build /app/build
@ -132,6 +79,4 @@ COPY --from=build /app/package.json /app/package.json
# copy backend files # copy backend files
COPY ./backend . COPY ./backend .
EXPOSE 8080 CMD [ "bash", "start.sh"]
CMD [ "bash", "start.sh"]

View file

@ -1,33 +0,0 @@
ifneq ($(shell which docker-compose 2>/dev/null),)
DOCKER_COMPOSE := docker-compose
else
DOCKER_COMPOSE := docker compose
endif
install:
$(DOCKER_COMPOSE) up -d
remove:
@chmod +x confirm_remove.sh
@./confirm_remove.sh
start:
$(DOCKER_COMPOSE) start
startAndBuild:
$(DOCKER_COMPOSE) up -d --build
stop:
$(DOCKER_COMPOSE) stop
update:
# Calls the LLM update script
chmod +x update_ollama_models.sh
@./update_ollama_models.sh
@git pull
$(DOCKER_COMPOSE) down
# Make sure the ollama-webui container is stopped before rebuilding
@docker stop open-webui || true
$(DOCKER_COMPOSE) up --build -d
$(DOCKER_COMPOSE) start

291
README.md
View file

@ -11,10 +11,12 @@
[![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s) [![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
[![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck) [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI designed to operate entirely offline. It supports various LLM runners, including Ollama and OpenAI-compatible APIs. For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/). User-friendly WebUI for LLMs, Inspired by ChatGPT
![Open WebUI Demo](./demo.gif) ![Open WebUI Demo](./demo.gif)
Also check our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles for Ollama! 🦙🔍
## Features ⭐ ## Features ⭐
- 🖥️ **Intuitive Interface**: Our chat interface takes inspiration from ChatGPT, ensuring a user-friendly experience. - 🖥️ **Intuitive Interface**: Our chat interface takes inspiration from ChatGPT, ensuring a user-friendly experience.
@ -25,28 +27,22 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience. - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience.
- 🌈 **Theme Customization**: Choose from a variety of themes to personalize your Open WebUI experience.
- 💻 **Code Syntax Highlighting**: Enjoy enhanced code readability with our syntax highlighting feature. - 💻 **Code Syntax Highlighting**: Enjoy enhanced code readability with our syntax highlighting feature.
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with the groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using `#` command in the prompt. In its alpha phase, occasional issues may arise as we actively refine and enhance this feature to ensure optimal performance and reliability. - 📚 **Local RAG Integration**: Dive into the future of chat interactions with the groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using `#` command in the prompt. In its alpha phase, occasional issues may arise as we actively refine and enhance this feature to ensure optimal performance and reliability.
- 🔍 **RAG Embedding Support**: Change the RAG embedding model directly in document settings, enhancing document processing. This feature supports Ollama and OpenAI models.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by the URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by the URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
- 📜 **Prompt Preset Support**: Instantly access preset prompts using the `/` command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [Open WebUI Community](https://openwebui.com/) integration. - 📜 **Prompt Preset Support**: Instantly access preset prompts using the `/` command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [Open WebUI Community](https://openwebui.com/) integration.
- 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, followed by the option to provide textual feedback, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data. - 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data.
- 🏷️ **Conversation Tagging**: Effortlessly categorize and locate specific chats for quick reference and streamlined data collection. - 🏷️ **Conversation Tagging**: Effortlessly categorize and locate specific chats for quick reference and streamlined data collection.
- 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI. - 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI.
- 🔄 **Update All Ollama Models**: Easily update locally installed models all at once with a convenient button, streamlining model management.
- ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face. - ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face.
- 🤖 **Multiple Model Support**: Seamlessly switch between different chat models for diverse interactions. - 🤖 **Multiple Model Support**: Seamlessly switch between different chat models for diverse interactions.
@ -59,102 +55,162 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d
- 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment. - 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.
- 🗨️ **Local Chat Sharing**: Generate and share chat links seamlessly between users, enhancing collaboration and communication. - 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
- 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history. - 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.
- 📜 **Chat History**: Effortlessly access and manage your conversation history. - 📜 **Chat History**: Effortlessly access and manage your conversation history.
- 📬 **Archive Chats**: Effortlessly store away completed conversations with LLMs for future reference, maintaining a tidy and clutter-free chat interface while allowing for easy retrieval and reference.
- 📤📥 **Import/Export Chat History**: Seamlessly move your chat data in and out of the platform. - 📤📥 **Import/Export Chat History**: Seamlessly move your chat data in and out of the platform.
- 🗣️ **Voice Input Support**: Engage with your model through voice interactions; enjoy the convenience of talking to your model directly. Additionally, explore the option for sending voice input automatically after 3 seconds of silence for a streamlined experience. - 🗣️ **Voice Input Support**: Engage with your model through voice interactions; enjoy the convenience of talking to your model directly. Additionally, explore the option for sending voice input automatically after 3 seconds of silence for a streamlined experience.
- 🔊 **Configurable Text-to-Speech Endpoint**: Customize your Text-to-Speech experience with configurable OpenAI endpoints.
- ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs. - ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.
- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API (local), ComfyUI (local), and DALL-E, enriching your chat experience with dynamic visual content.
- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
- ✨ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
- 🔑 **API Key Generation Support**: Generate secret keys to leverage Open WebUI with OpenAI libraries, simplifying integration and development.
- 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable. - 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.
- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
- 🔗 **Webhook Integration**: Subscribe to new user sign-up events via webhook (compatible with Google Chat and Microsoft Teams), providing real-time notifications and automation capabilities.
- 🛡️ **Model Whitelisting**: Admins can whitelist models for users with the 'user' role, enhancing security and access control.
- 📧 **Trusted Email Authentication**: Authenticate using a trusted email header, adding an additional layer of security and authentication.
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators. - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
- 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security. - 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
- 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates and new features. - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates and new features.
## 🔗 Also Check Out Open WebUI Community! ## 🔗 Also Check Out Open WebUI Community!
Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀 Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Ollama! 🚀
## How to Install 🚀 ## How to Install 🚀
> [!NOTE] 🌟 **Important Note on User Roles and Privacy:**
> Please note that for certain Docker environments, additional configurations might be needed. If you encounter any connection issues, our detailed guide on [Open WebUI Documentation](https://docs.openwebui.com/) is ready to assist you.
### Quick Start with Docker 🐳 - **Admin Creation:** The very first account to sign up on Open WebUI will be granted **Administrator privileges**. This account will have comprehensive control over the platform, including user management and system settings.
> [!WARNING] - **User Registrations:** All subsequent users signing up will initially have their accounts set to **Pending** status by default. These accounts will require approval from the Administrator to gain access to the platform functionalities.
> When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
> [!TIP] - **Privacy and Data Security:** We prioritize your privacy and data security above all. Please be reassured that all data entered into Open WebUI is stored locally on your device. Our system is designed to be privacy-first, ensuring that no external requests are made, and your data does not leave your local environment. We are committed to maintaining the highest standards of data privacy and security, ensuring that your information remains confidential and under your control.
> If you wish to utilize Open WebUI with Ollama included or CUDA acceleration, we recommend utilizing our official images tagged with either `:cuda` or `:ollama`. To enable CUDA, you must install the [Nvidia CUDA container toolkit](https://docs.nvidia.com/dgx/nvidia-container-runtime-upgrade/) on your Linux/WSL system.
**If Ollama is on your computer**, use this command: ### Steps to Install Open WebUI
```bash #### Before You Begin
docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
**If Ollama is on a Different Server**, use this command: 1. **Installing Docker:**
To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL: - **For Windows and Mac Users:**
```bash - Download Docker Desktop from [Docker's official website](https://www.docker.com/products/docker-desktop).
docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main - Follow the installation instructions provided on the website. After installation, open Docker Desktop to ensure it's running properly.
```
After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄 - **For Ubuntu and Other Linux Users:**
- Open your terminal.
- Set up your Docker apt repository according to the [Docker documentation](https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository)
- Update your package index:
```bash
sudo apt-get update
```
- Install Docker using the following command:
```bash
sudo apt-get install docker-ce docker-ce-cli containerd.io
```
- Verify the Docker installation with:
```bash
sudo docker run hello-world
```
This command downloads a test image and runs it in a container, which prints an informational message.
#### Open WebUI: Server Connection Error 2. **Ensure You Have the Latest Version of Ollama:**
If you're experiencing connection issues, its often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`. - Download the latest version from [https://ollama.com/](https://ollama.com/).
**Example Docker Command**: 3. **Verify Ollama Installation:**
- After installing Ollama, check if it's working by visiting [http://127.0.0.1:11434/](http://127.0.0.1:11434/) in your web browser. Remember, the port number might be different for you.
```bash #### Installing with Docker 🐳
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
### Other Installation Methods - **Important:** When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
We offer various installation alternatives, including non-Docker methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance. - **If Ollama is on your computer**, use this command:
### Troubleshooting ```bash
docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s). - **To build the container yourself**, follow these steps:
### Keeping Your Docker Installation Up-to-Date ```bash
docker build -t open-webui .
docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always open-webui
```
- After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000).
#### Using Ollama on a Different Server
- To connect to Ollama on another server, change the `OLLAMA_API_BASE_URL` to the server's URL:
```bash
docker run -d -p 3000:8080 -e OLLAMA_API_BASE_URL=https://example.com/api -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
Or for a self-built container:
```bash
docker build -t open-webui .
docker run -d -p 3000:8080 -e OLLAMA_API_BASE_URL=https://example.com/api -v open-webui:/app/backend/data --name open-webui --restart always open-webui
```
### Installing Ollama and Open WebUI Together
#### Using Docker Compose
- If you don't have Ollama yet, use Docker Compose for easy installation. Run this command:
```bash
docker compose up -d --build
```
- **For GPU Support:** Use an additional Docker Compose file:
```bash
docker compose -f docker-compose.yaml -f docker-compose.gpu.yaml up -d --build
```
- **To Expose Ollama API:** Use another Docker Compose file:
```bash
docker compose -f docker-compose.yaml -f docker-compose.api.yaml up -d --build
```
#### Using `run-compose.sh` Script (Linux or Docker-Enabled WSL2 on Windows)
- Give execute permission to the script:
```bash
chmod +x run-compose.sh
```
- For CPU-only container:
```bash
./run-compose.sh
```
- For GPU support (read the note about GPU compatibility):
```bash
./run-compose.sh --enable-gpu
```
- To build the latest local version, add `--build`:
```bash
./run-compose.sh --enable-gpu --build
```
### Alternative Installation Methods
For other ways to install, like using Kustomize or Helm, check out [INSTALLATION.md](/INSTALLATION.md). Join our [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) for more help and information.
### Updating your Docker Installation
In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/): In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
@ -166,11 +222,104 @@ In the last part of the command, replace `open-webui` with your container name i
### Moving from Ollama WebUI to Open WebUI ### Moving from Ollama WebUI to Open WebUI
Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/). Given recent name changes, the docker image has been renamed. Additional steps are required to update for those people that used Ollama WebUI previously and want to start using the new images.
## What's Next? 🌟 #### Updating to Open WebUI without keeping your data
Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/). If you want to update to the new image but don't want to keep any previous data like conversations, prompts, documents, etc. you can perform the following steps:
```bash
docker rm -f ollama-webui
docker pull ghcr.io/open-webui/open-webui:main
[insert the equivalent command that you used to install with the new Docker image name]
docker volume rm ollama-webui
```
For example, for local installation it would be `docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main`. For other installation commands, check the relevant parts of this README document.
#### Migrating your contents from Ollama WebUI to Open WebUI
If you want to update to the new image migrating all your previous settings like conversations, prompts, documents, etc. you can perform the following steps:
```bash
docker rm -f ollama-webui
docker pull ghcr.io/open-webui/open-webui:main
# Creates a new volume and uses a temporary container to copy from one volume to another as per https://github.com/moby/moby/issues/31154#issuecomment-360531460
docker volume create --name open-webui
docker run --rm -v ollama-webui:/from -v open-webui:/to alpine ash -c "cd /from ; cp -av . /to"
[insert the equivalent command that you used to install with the new Docker image name]
```
Once you verify that all the data has been migrated you can erase the old volumen using the following command:
```bash
docker volume rm ollama-webui
```
## How to Install Without Docker
While we strongly recommend using our convenient Docker container installation for optimal support, we understand that some situations may require a non-Docker setup, especially for development purposes. Please note that non-Docker installations are not officially supported, and you might need to troubleshoot on your own.
### Project Components
Open WebUI consists of two primary components: the frontend and the backend (which serves as a reverse proxy, handling static frontend files, and additional features). Both need to be running concurrently for the development environment.
> [!IMPORTANT]
> The backend is required for proper functionality
### Requirements 📦
- 🐰 [Bun](https://bun.sh) >= 1.0.21 or 🐢 [Node.js](https://nodejs.org/en) >= 20.10
- 🐍 [Python](https://python.org) >= 3.11
### Build and Install 🛠️
Run the following commands to install:
```sh
git clone https://github.com/open-webui/open-webui.git
cd open-webui/
# Copying required .env file
cp -RPp .env.example .env
# Building Frontend Using Node
npm i
npm run build
# or Building Frontend Using Bun
# bun install
# bun run build
# Serving Frontend with the Backend
cd ./backend
pip install -r requirements.txt -U
sh start.sh
```
You should have Open WebUI up and running at http://localhost:8080/. Enjoy! 😄
## Troubleshooting
See [TROUBLESHOOTING.md](/TROUBLESHOOTING.md) for information on how to troubleshoot and/or join our [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s).
## What's Next? 🚀
### Roadmap 📝
Here are some exciting tasks on our roadmap:
- 🔊 **Local Text-to-Speech Integration**: Seamlessly incorporate text-to-speech functionality directly within the platform, allowing for a smoother and more immersive user experience.
- 🛡️ **Granular Permissions and User Groups**: Empower administrators to finely control access levels and group users according to their roles and responsibilities. This feature ensures robust security measures and streamlined management of user privileges, enhancing overall platform functionality.
- 🔄 **Function Calling**: Empower your interactions by running code directly within the chat. Execute functions and commands effortlessly, enhancing the functionality of your conversations.
- ⚙️ **Custom Python Backend Actions**: Empower your Open WebUI by creating or downloading custom Python backend actions. Unleash the full potential of your web interface with tailored actions that suit your specific needs, enhancing functionality and versatility.
- 🔧 **Fine-tune Model (LoRA)**: Fine-tune your model directly from the user interface. This feature allows for precise customization and optimization of the chat experience to better suit your needs and preferences.
- 🧠 **Long-Term Memory**: Witness the power of persistent memory in our agents. Enjoy conversations that feel continuous as agents remember and reference past interactions, creating a more cohesive and personalized user experience.
- 🧪 **Research-Centric Features**: Empower researchers in the fields of LLM and HCI with a comprehensive web UI for conducting user studies. Stay tuned for ongoing feature enhancements (e.g., surveys, analytics, and participant tracking) to facilitate their research.
- 📈 **User Study Tools**: Providing specialized tools, like heat maps and behavior tracking modules, to empower researchers in capturing and analyzing user behavior patterns with precision and accuracy.
- 📚 **Enhanced Documentation**: Elevate your setup and customization experience with improved, comprehensive documentation.
Feel free to contribute and help us make Open WebUI even better! 🙌
## Supporters ✨ ## Supporters ✨
@ -193,16 +342,6 @@ This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LI
If you have any questions, suggestions, or need assistance, please open an issue or join our If you have any questions, suggestions, or need assistance, please open an issue or join our
[Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! 🤝 [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! 🤝
## Star History
<a href="https://star-history.com/#open-webui/open-webui&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
</picture>
</a>
--- ---
Created by [Timothy J. Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪 Created by [Timothy J. Baek](https://github.com/tjbck) - Let's make Open Web UI even more amazing together! 💪

View file

@ -4,7 +4,7 @@
The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues. The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend. - **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama/api` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_API_BASE_URL` environment variable. Therefore, a request made to `/ollama/api` in the WebUI is effectively the same as making a request to `OLLAMA_API_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_API_BASE_URL/tags` in the backend.
- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer. - **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
@ -15,7 +15,7 @@ If you're experiencing connection issues, its often due to the WebUI docker c
**Example Docker Command**: **Example Docker Command**:
```bash ```bash
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api --name open-webui --restart always ghcr.io/open-webui/open-webui:main
``` ```
### General Connection Errors ### General Connection Errors
@ -25,8 +25,8 @@ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=
**Troubleshooting Steps**: **Troubleshooting Steps**:
1. **Verify Ollama URL Format**: 1. **Verify Ollama URL Format**:
- When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups). - When running the Web UI container, ensure the `OLLAMA_API_BASE_URL` is correctly set, including the `/api` suffix. (e.g., `http://192.168.1.1:11434/api` for different host setups).
- In the Open WebUI, navigate to "Settings" > "General". - In the Open WebUI, navigate to "Settings" > "General".
- Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`). - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]/api` (e.g., `http://localhost:11434/api`), including the `/api` suffix.
By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord. By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.

View file

@ -4,11 +4,4 @@ _old
uploads uploads
.ipynb_checkpoints .ipynb_checkpoints
*.db *.db
_test _test
!/data
/data/*
!/data/litellm
/data/litellm/*
!data/litellm/config.yaml
!data/config.json

7
backend/.gitignore vendored
View file

@ -6,11 +6,6 @@ uploads
*.db *.db
_test _test
Pipfile Pipfile
!/data data/*
/data/*
!/data/litellm
/data/litellm/*
!data/litellm/config.yaml
!data/config.json !data/config.json
.webui_secret_key .webui_secret_key

View file

@ -1,5 +1,4 @@
import os import os
import logging
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request, Request,
@ -10,19 +9,8 @@ from fastapi import (
File, File,
Form, Form,
) )
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from pydantic import BaseModel
import requests
import hashlib
from pathlib import Path
import json
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
@ -33,20 +21,7 @@ from utils.utils import (
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
from config import ( from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
SRC_LOG_LEVELS,
CACHE_DIR,
UPLOAD_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE,
DEVICE_TYPE,
AUDIO_OPENAI_API_BASE_URL,
AUDIO_OPENAI_API_KEY,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
@ -58,108 +33,12 @@ app.add_middleware(
) )
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL @app.post("/transcribe")
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
@app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(body).hexdigest()
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
# Return the saved file
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r != None else 500,
detail=error_detail,
)
@app.post("/transcriptions")
def transcribe( def transcribe(
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
log.info(f"file.content_type: {file.content_type}") print(file.content_type)
if file.content_type not in ["audio/mpeg", "audio/wav"]: if file.content_type not in ["audio/mpeg", "audio/wav"]:
raise HTTPException( raise HTTPException(
@ -175,27 +54,15 @@ def transcribe(
f.write(contents) f.write(contents)
f.close() f.close()
whisper_kwargs = { model = WhisperModel(
"model_size_or_path": WHISPER_MODEL, WHISPER_MODEL,
"device": whisper_device_type, device="auto",
"compute_type": "int8", compute_type="int8",
"download_root": WHISPER_MODEL_DIR, download_root=WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE, )
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
segments, info = model.transcribe(file_path, beam_size=5) segments, info = model.transcribe(file_path, beam_size=5)
log.info( print(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
% (info.language, info.language_probability) % (info.language, info.language_probability)
) )
@ -205,7 +72,7 @@ def transcribe(
return {"text": transcript.strip()} return {"text": transcript.strip()}
except Exception as e: except Exception as e:
log.exception(e) print(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View file

@ -0,0 +1,172 @@
from pathlib import Path
import ast
import builtins
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from apps.functions.security import ALLOWED_MODULES, ALLOWED_BUILTINS, custom_import
from utils.utils import get_current_user, get_admin_user
from config import FUNCTIONS_DIR
from constants import ERROR_MESSAGES
from pydantic import BaseModel
from typing import Optional
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def get_status():
return {"status": True}
class FunctionForm(BaseModel):
name: str
content: str
@app.post("/add")
def add_function(
form_data: FunctionForm,
user=Depends(get_admin_user),
):
try:
filename = f"{FUNCTIONS_DIR}/{form_data.name}.py"
if not Path(filename).exists():
with open(filename, "w") as file:
file.write(form_data.content)
return f"{form_data.name}.py" in list(
map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))
)
else:
raise Exception("Function already exists")
except Exception as e:
print(e)
return False
@app.post("/update")
def update_function(
form_data: FunctionForm,
user=Depends(get_admin_user),
):
try:
filename = f"{FUNCTIONS_DIR}/{form_data.name}.py"
if Path(filename).exists():
with open(filename, "w") as file:
file.write(form_data.content)
return f"{form_data.name}.py" in list(
map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))
)
else:
raise Exception("Function does not exist")
except Exception as e:
print(e)
return False
@app.get("/check/{function}")
def check_function(
function: str,
user=Depends(get_admin_user),
):
filename = f"{FUNCTIONS_DIR}/{function}.py"
# Check if the function file exists
if not Path(filename).is_file():
raise HTTPException(status_code=404, detail="Function not found")
# Read the code from the file
with open(filename, "r") as file:
code = file.read()
return {"name": function, "content": code}
@app.get("/list")
def list_functions(
user=Depends(get_admin_user),
):
files = list(map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*")))
return files
def validate_imports(code):
try:
tree = ast.parse(code)
except SyntaxError as e:
raise HTTPException(status_code=400, detail=f"Syntax error in function: {e}")
for node in ast.walk(tree):
if isinstance(node, ast.Import):
module_names = [alias.name for alias in node.names]
elif isinstance(node, ast.ImportFrom):
module_names = [node.module]
else:
continue
for name in module_names:
if name not in ALLOWED_MODULES:
raise HTTPException(
status_code=400, detail=f"Import of module {name} is not allowed"
)
@app.post("/exec/{function}")
def exec_function(
function: str,
kwargs: Optional[dict] = None,
user=Depends(get_current_user),
):
filename = f"{FUNCTIONS_DIR}/{function}.py"
# Check if the function file exists
if not Path(filename).is_file():
raise HTTPException(status_code=404, detail="Function not found")
# Read the code from the file
with open(filename, "r") as file:
code = file.read()
validate_imports(code)
try:
# Execute the code within a restricted namespace
namespace = {name: getattr(builtins, name) for name in ALLOWED_BUILTINS}
namespace["__import__"] = custom_import
exec(code, namespace)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Function: {e}")
# Check if the function exists in the namespace
if "main" not in namespace or not callable(namespace["main"]):
raise HTTPException(status_code=400, detail="Invalid function")
try:
# Execute the function with provided kwargs
result = namespace["main"](kwargs) if kwargs else namespace["main"]()
return result
except Exception as e:
raise HTTPException(status_code=400, detail=f"Function: {e}")

View file

@ -0,0 +1,140 @@
ALLOWED_MODULES = {
"pydantic",
"math",
"json",
"time",
"datetime",
"requests",
} # Add allowed modules here
def custom_import(name, globals=None, locals=None, fromlist=(), level=0):
if name in ALLOWED_MODULES:
return __import__(name, globals, locals, fromlist, level)
raise ImportError(f"Import of module {name} is not allowed")
# Define a restricted set of builtins
ALLOWED_BUILTINS = {
"ArithmeticError",
"AssertionError",
"AttributeError",
"BaseException",
"BufferError",
"BytesWarning",
"DeprecationWarning",
"EOFError",
"Ellipsis",
"EnvironmentError",
"Exception",
"False",
"FloatingPointError",
"FutureWarning",
"GeneratorExit",
"IOError",
"ImportError",
"ImportWarning",
"IndentationError",
"IndexError",
"KeyError",
"KeyboardInterrupt",
"LookupError",
"MemoryError",
"NameError",
"None",
"NotImplemented",
"NotImplementedError",
"OSError",
"OverflowError",
"PendingDeprecationWarning",
"ReferenceError",
"RuntimeError",
"RuntimeWarning",
"StopIteration",
"SyntaxError",
"SyntaxWarning",
"SystemError",
"SystemExit",
"TabError",
"True",
"TypeError",
"UnboundLocalError",
"UnicodeDecodeError",
"UnicodeEncodeError",
"UnicodeError",
"UnicodeTranslateError",
"UnicodeWarning",
"UserWarning",
"ValueError",
"Warning",
"ZeroDivisionError",
"__build_class__",
"__debug__",
"__import__",
"abs",
"all",
"any",
"ascii",
"bin",
"bool",
"bytearray",
"bytes",
"callable",
"chr",
"classmethod",
"compile",
"complex",
"delattr",
"dict",
"dir",
"divmod",
"enumerate",
"eval",
"exec",
"filter",
"float",
"format",
"frozenset",
"getattr",
"globals",
"hasattr",
"hash",
"hex",
"id",
"input",
"int",
"isinstance",
"issubclass",
"iter",
"len",
"list",
"locals",
"map",
"max",
"memoryview",
"min",
"next",
"object",
"oct",
"open",
"ord",
"pow",
"print",
"property",
"range",
"repr",
"reversed",
"round",
"set",
"setattr",
"slice",
"sorted",
"staticmethod",
"str",
"sum",
"super",
"tuple",
"type",
"vars",
"zip",
}

View file

@ -18,38 +18,10 @@ from utils.utils import (
get_current_user, get_current_user,
get_admin_user, get_admin_user,
) )
from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from pathlib import Path from config import AUTOMATIC1111_BASE_URL
import mimetypes
import uuid
import base64
import json
import logging
from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
IMAGE_GENERATION_ENGINE,
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
@ -60,116 +32,49 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.ENGINE = IMAGE_GENERATION_ENGINE
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.MODEL = IMAGE_GENERATION_MODEL
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_SIZE = IMAGE_SIZE @app.get("/enabled", response_model=bool)
app.state.IMAGE_STEPS = IMAGE_STEPS async def get_enable_status(request: Request, user=Depends(get_admin_user)):
return app.state.ENABLED
@app.get("/config") @app.get("/enabled/toggle", response_model=bool)
async def get_config(request: Request, user=Depends(get_admin_user)): async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} try:
r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
app.state.ENABLED = not app.state.ENABLED
return app.state.ENABLED
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
class ConfigUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
engine: str url: str
enabled: bool
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class EngineUrlUpdateForm(BaseModel):
AUTOMATIC1111_BASE_URL: Optional[str] = None
COMFYUI_BASE_URL: Optional[str] = None
@app.get("/url") @app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)): async def get_openai_url(user=Depends(get_admin_user)):
return { return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
}
@app.post("/url/update") @app.post("/url/update")
async def update_engine_url( async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.url == "":
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
try:
r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"status": True, "status": True,
} }
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/openai/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
class ImageSizeUpdateForm(BaseModel): class ImageSizeUpdateForm(BaseModel):
size: str size: str
@ -197,82 +102,25 @@ async def update_image_size(
) )
class ImageStepsUpdateForm(BaseModel):
steps: int
@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
@app.post("/steps/update")
async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps
return {
"IMAGE_STEPS": app.state.IMAGE_STEPS,
"status": True,
}
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
)
@app.get("/models") @app.get("/models")
def get_models(user=Depends(get_current_user)): def get_models(user=Depends(get_current_user)):
try: try:
if app.state.ENGINE == "openai": r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
return [ models = r.json()
{"id": "dall-e-2", "name": "DALL·E 2"}, return models
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif app.state.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
info = r.json()
return list(
map(
lambda model: {"id": model, "name": model},
info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
)
)
else:
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
)
models = r.json()
return list(
map(
lambda model: {"id": model["title"], "name": model["model_name"]},
models,
)
)
except Exception as e: except Exception as e:
app.state.ENABLED = False raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default") @app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)): async def get_default_model(user=Depends(get_admin_user)):
try: try:
if app.state.ENGINE == "openai": r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} options = r.json()
elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""} return {"model": options["sd_model_checkpoint"]}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e: except Exception as e:
app.state.ENABLED = False raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel): class UpdateModelForm(BaseModel):
@ -280,23 +128,16 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str): def set_model_handler(model: str):
if app.state.ENGINE == "openai": r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
app.state.MODEL = model options = r.json()
return app.state.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]: if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model options["sd_model_checkpoint"] = model
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
) )
return options return options
@app.post("/models/default/update") @app.post("/models/default/update")
@ -311,196 +152,42 @@ class GenerateImageForm(BaseModel):
model: Optional[str] = None model: Optional[str] = None
prompt: str prompt: str
n: int = 1 n: int = 1
size: Optional[str] = None size: str = "512x512"
negative_prompt: Optional[str] = None negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
try:
image_id = str(uuid.uuid4())
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded)
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
else:
image_filename = f"{image_id}.png"
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
img_data = base64.b64decode(b64_str)
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
except Exception as e:
log.exception(f"Error saving image: {e}")
return None
def save_url_image(url):
image_id = str(uuid.uuid4())
try:
r = requests.get(url)
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_filename
else:
log.error(f"Url does not point to an image.")
return None
except Exception as e:
log.exception(f"Error saving image: {e}")
return None
@app.post("/generations") @app.post("/generations")
def generate_image( def generate_image(
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) print(form_data)
r = None
try: try:
if app.state.ENGINE == "openai": if form_data.model:
set_model_handler(form_data.model)
headers = {} width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
data = { data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "prompt": form_data.prompt,
"prompt": form_data.prompt, "batch_size": form_data.n,
"n": form_data.n, "width": width,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE, "height": height,
"response_format": "b64_json", }
}
r = requests.post( if form_data.negative_prompt != None:
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", data["negative_prompt"] = form_data.negative_prompt
json=data,
headers=headers,
)
r.raise_for_status() print(data)
res = r.json()
images = [] r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
for image in res["data"]: json=data,
image_filename = save_b64_image(image["b64_json"]) )
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(data, f)
return images
elif app.state.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
"height": height,
"n": form_data.n,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
app.state.MODEL,
data,
user.id,
app.state.COMFYUI_BASE_URL,
)
log.debug(f"res: {res}")
images = []
for image in res["data"]:
image_filename = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(data.model_dump(exclude_none=True), f)
log.debug(f"images: {images}")
return images
else:
if form_data.model:
set_model_handler(form_data.model)
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
res = r.json()
log.debug(f"res: {res}")
images = []
for image in res["images"]:
image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump({**data, "info": res["info"]}, f)
return images
return r.json()
except Exception as e: except Exception as e:
error = e print(e)
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
if r != None:
data = r.json()
if "error" in data:
error = data["error"]["message"]
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))

View file

@ -1,234 +0,0 @@
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
import logging
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
from pydantic import BaseModel
from typing import Optional
COMFYUI_DEFAULT_PROMPT = """
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "Negative Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
def queue_prompt(prompt, client_id, base_url):
log.info("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type, base_url):
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
return response.read()
def get_image_url(filename, subfolder, folder_type, base_url):
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
log.info("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read())
def get_images(ws, prompt, client_id, base_url):
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
output_images = []
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message["type"] == "executing":
data = message["data"]
if data["node"] is None and data["prompt_id"] == prompt_id:
break # Execution is done
else:
continue # previews are binary data
history = get_history(prompt_id, base_url)[prompt_id]
for o in history["outputs"]:
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
if "images" in node_output:
for image in node_output["images"]:
url = get_image_url(
image["filename"], image["subfolder"], image["type"], base_url
)
output_images.append({"url": url})
return {"data": output_images}
class ImageGenerationPayload(BaseModel):
prompt: str
negative_prompt: Optional[str] = ""
steps: Optional[int] = None
seed: Optional[int] = None
width: int
height: int
n: int = 1
def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width
comfyui_prompt["5"]["inputs"]["height"] = payload.height
# set the text prompt for our positive CLIPTextEncode
comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
if payload.steps:
comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
comfyui_prompt["3"]["inputs"]["seed"] = (
payload.seed if payload.seed else random.randint(0, 18446744073709551614)
)
try:
ws = websocket.WebSocket()
ws.connect(f"{ws_url}/ws?clientId={client_id}")
log.info("WebSocket connection established.")
except Exception as e:
log.exception(f"Failed to connect to WebSocket server: {e}")
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e:
log.exception(f"Error while receiving images: {e}")
images = None
ws.close()
return images

View file

@ -1,372 +0,0 @@
import sys
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
import time
import requests
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES
import os
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
DATA_DIR,
LITELLM_PROXY_PORT,
LITELLM_PROXY_HOST,
)
from litellm.utils import get_llm_provider
import asyncio
import subprocess
import yaml
app = FastAPI()
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file)
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference
background_process = None
CONFLICT_ENV_VARS = [
# Uvicorn uses PORT, so LiteLLM might use it as well
"PORT",
# LiteLLM uses DATABASE_URL for Prisma connections
"DATABASE_URL",
]
async def run_background_process(command):
global background_process
log.info("run_background_process")
try:
# Log the command to be executed
log.info(f"Executing command: {command}")
# Filter environment variables known to conflict with litellm
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
background_process = process
log.info("Subprocess started successfully.")
# Capture STDERR for debugging purposes
stderr_output = await process.stderr.read()
stderr_text = stderr_output.decode().strip()
if stderr_text:
log.info(f"Subprocess STDERR: {stderr_text}")
# log.info output line by line
async for line in process.stdout:
log.info(line.decode().strip())
# Wait for the process to finish
returncode = await process.wait()
log.info(f"Subprocess exited with return code {returncode}")
except Exception as e:
log.error(f"Failed to start subprocess: {e}")
raise # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
log.info("start_litellm_background")
# Command to run in the background
command = [
"litellm",
"--port",
str(LITELLM_PROXY_PORT),
"--host",
LITELLM_PROXY_HOST,
"--telemetry",
"False",
"--config",
LITELLM_CONFIG_DIR,
]
await run_background_process(command)
async def shutdown_litellm_background():
log.info("shutdown_litellm_background")
global background_process
if background_process:
background_process.terminate()
await background_process.wait() # Ensure the process has terminated
log.info("Subprocess terminated")
background_process = None
@app.on_event("startup")
async def startup_event():
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.get("/")
async def get_status():
return {"status": True}
async def restart_litellm():
"""
Endpoint to restart the litellm background service.
"""
log.info("Requested restart of litellm service.")
try:
# Shut down the existing process if it is running
await shutdown_litellm_background()
log.info("litellm service shutdown complete.")
# Restart the background service
asyncio.create_task(start_litellm_background())
log.info("litellm service restart complete.")
return {
"status": "success",
"message": "litellm service restarted successfully.",
}
except Exception as e:
log.info(f"Error restarting litellm service: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
return await restart_litellm()
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return app.state.CONFIG
class LiteLLMConfigForm(BaseModel):
general_settings: Optional[dict] = None
litellm_settings: Optional[dict] = None
model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
app.state.CONFIG = form_data.model_dump(exclude_none=True)
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return app.state.CONFIG
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
if app.state.ENABLE:
while not background_process:
await asyncio.sleep(0.1)
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
data = r.json()
if app.state.ENABLE_MODEL_FILTER:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
return {
"data": [
{
"id": model["model_name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in app.state.CONFIG["model_list"]
],
"object": "list",
}
else:
return {
"data": [],
"object": "list",
}
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):
id: str
@app.post("/model/delete")
async def delete_model_from_config(
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
app.state.CONFIG["model_list"] = [
model
for model in app.state.CONFIG["model_list"]
if model["model_name"] != form_data.id
]
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
url = f"http://localhost:{LITELLM_PROXY_PORT}"
target_url = f"{url}/{path}"
headers = {}
# headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
response_data = r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,127 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
import aiohttp
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel):
url: str
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# async def fetch_sse(method, target_url, body, headers):
# async with aiohttp.ClientSession() as session:
# try:
# async with session.request(
# method, target_url, data=body, headers=headers
# ) as response:
# print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Open WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
print(target_url)
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None)
headers.pop("Authorization", None)
headers.pop("Origin", None)
headers.pop("Referer", None)
session = aiohttp.ClientSession()
response = None
try:
response = await session.request(
request.method, target_url, data=body, headers=headers
)
print(response)
if not response.ok:
data = await response.json()
print(data)
response.raise_for_status()
async def generate():
async for line in response.content:
print(line)
yield line
await session.close()
return StreamingResponse(generate(), response.status)
except Exception as e:
print(e)
error_detail = "Open WebUI: Server Connection Error"
if response is not None:
try:
res = await response.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
await session.close()
raise HTTPException(
status_code=response.status if response else 500,
detail=error_detail,
)

View file

@ -3,11 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
import requests import requests
import aiohttp
import asyncio
import json import json
import logging
from pydantic import BaseModel from pydantic import BaseModel
@ -19,23 +15,11 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from config import ( from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
SRC_LOG_LEVELS,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
)
from typing import List, Optional
import hashlib import hashlib
from pathlib import Path from pathlib import Path
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -45,278 +29,132 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.OPENAI_API_KEY = OPENAI_API_KEY
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
@app.middleware("http") class UrlUpdateForm(BaseModel):
async def check_url(request: Request, call_next): url: str
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
response = await call_next(request)
return response
class UrlsUpdateForm(BaseModel): class KeyUpdateForm(BaseModel):
urls: List[str] key: str
class KeysUpdateForm(BaseModel): @app.get("/url")
keys: List[str] async def get_openai_url(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.get("/urls") @app.post("/url/update")
async def get_openai_urls(user=Depends(get_admin_user)): async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.post("/urls/update") @app.get("/key")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): async def get_openai_key(user=Depends(get_admin_user)):
await get_all_models() return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
app.state.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
@app.get("/keys") @app.post("/key/update")
async def get_openai_keys(user=Depends(get_admin_user)): async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body()
name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try: try:
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") print("openai")
body = await request.body() r = requests.post(
name = hashlib.sha256(body).hexdigest() url=target_url,
data=body,
headers=headers,
stream=True,
)
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") r.raise_for_status()
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache # Save the streaming content to a file
if file_path.is_file(): with open(file_path, "wb") as f:
return FileResponse(file_path) for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
headers = {} with open(file_body_path, "w") as f:
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" json.dump(json.loads(body.decode("utf-8")), f)
headers["Content-Type"] = "application/json"
r = None # Return the saved file
try: return FileResponse(file_path)
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
# Return the saved file
return FileResponse(file_path)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def fetch_url(url, key):
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e: except Exception as e:
# Handle connection error here print(e)
log.error(f"Connection error: {e}") error_detail = "Open WebUI: Server Connection Error"
return None if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status_code, detail=error_detail)
def merge_models_lists(model_lists):
log.info(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{**model, "urlIdx": idx}
for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
return merged_list
async def get_all_models():
log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
models = {"data": []}
else:
tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
log.info(f"get_all_models:responses() {responses}")
models = {
"data": merge_models_lists(
list(
map(
lambda response: (
response["data"]
if (response and "data" in response)
else (response if isinstance(response, list) else None)
),
responses,
)
)
)
}
log.info(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None:
models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
models["data"],
)
)
return models
return models
else:
url = app.state.OPENAI_API_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
response_data = r.json()
if "api.openai.com" in url:
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)): async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0 target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body() body = await request.body()
# TODO: Remove below after gpt-4-vision fix from Open AI # TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
try: try:
body = body.decode("utf-8") body = body.decode("utf-8")
body = json.loads(body) body = json.loads(body)
idx = app.state.MODELS[body.get("model")]["urlIdx"]
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model # This is a workaround until OpenAI fixes the issue with this model
if body.get("model") == "gpt-4-vision-preview": if body.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in body: if "max_tokens" not in body:
body["max_tokens"] = 4000 body["max_tokens"] = 4000
log.debug("Modified body_dict:", body) print("Modified body_dict:", body)
# Fix for ChatGPT calls failing because the num_ctx key is in body
if "num_ctx" in body:
# If 'num_ctx' is in the dictionary, delete it
# Leaving it there generates an error with the
# OpenAI API (Feb 2024)
del body["num_ctx"]
# Convert the modified body back to JSON # Convert the modified body back to JSON
body = json.dumps(body) body = json.dumps(body)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e) print("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
if key == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {key}" headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None
try: try:
r = requests.request( r = requests.request(
method=request.method, method=request.method,
@ -336,19 +174,31 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers=dict(r.headers), headers=dict(r.headers),
) )
else: else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )
response_data = r.json() response_data = r.json()
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
return response_data return response_data
except Exception as e: except Exception as e:
log.exception(e) print(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" error_detail = f"External: {res['error']}"
except: except:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(status_code=r.status_code, detail=error_detail)
status_code=r.status_code if r else 500, detail=error_detail
)

View file

@ -8,19 +8,19 @@ from fastapi import (
Form, Form,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import os, shutil, logging, re import os, shutil
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from chromadb.utils.batch_utils import create_batches from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
WebBaseLoader, WebBaseLoader,
TextLoader, TextLoader,
PyPDFLoader, PyPDFLoader,
CSVLoader, CSVLoader,
BSHTMLLoader,
Docx2txtLoader, Docx2txtLoader,
UnstructuredEPubLoader, UnstructuredEPubLoader,
UnstructuredWordDocumentLoader, UnstructuredWordDocumentLoader,
@ -28,22 +28,15 @@ from langchain_community.document_loaders import (
UnstructuredXMLLoader, UnstructuredXMLLoader,
UnstructuredRSTLoader, UnstructuredRSTLoader,
UnstructuredExcelLoader, UnstructuredExcelLoader,
YoutubeLoader,
) )
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
import validators
import urllib.parse
import socket
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
import mimetypes import mimetypes
import uuid import uuid
import json import json
import sentence_transformers
from apps.web.models.documents import ( from apps.web.models.documents import (
Documents, Documents,
@ -51,15 +44,6 @@ from apps.web.models.documents import (
DocumentResponse, DocumentResponse,
) )
from apps.rag.utils import (
get_model_path,
get_embedding_function,
query_doc,
query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
)
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
calculate_sha256_string, calculate_sha256_string,
@ -67,108 +51,44 @@ from utils.misc import (
extract_folders_after_data_docs, extract_folders_after_data_docs,
) )
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from config import ( from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR, UPLOAD_DIR,
DOCS_DIR, DOCS_DIR,
RAG_TOP_K,
RAG_RELEVANCE_THRESHOLD,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_DEVICE_TYPE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
ENABLE_RAG_HYBRID_SEARCH,
RAG_RERANKING_MODEL,
PDF_EXTRACT_IMAGES,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
DEVICE_TYPE,
CHROMA_CLIENT, CHROMA_CLIENT,
CHUNK_SIZE, CHUNK_SIZE,
CHUNK_OVERLAP, CHUNK_OVERLAP,
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
log = logging.getLogger(__name__) #
log.setLevel(SRC_LOG_LEVELS["RAG"]) # if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
app = FastAPI() app = FastAPI()
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.sentence_transformer_ef = (
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
def update_embedding_model(
embedding_model: str,
update_model: bool = False,
):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
else:
app.state.sentence_transformer_ef = None
def update_reranking_model(
reranking_model: str,
update_model: bool = False,
):
if reranking_model:
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)
else:
app.state.sentence_transformer_rf = None
update_embedding_model(
app.state.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
) )
update_reranking_model(
app.state.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
origins = ["*"] origins = ["*"]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
@ -182,10 +102,37 @@ class CollectionNameForm(BaseModel):
collection_name: Optional[str] = "test" collection_name: Optional[str] = "test"
class UrlForm(CollectionNameForm): class StoreWebForm(CollectionNameForm):
url: str url: str
def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return { return {
@ -193,121 +140,46 @@ async def get_status():
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE, "template": app.state.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL,
} }
@app.get("/embedding") @app.get("/embedding/model")
async def get_embedding_config(user=Depends(get_admin_user)): async def get_embedding_model(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
},
} }
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
class OpenAIConfigForm(BaseModel):
url: str
key: str
class EmbeddingModelUpdateForm(BaseModel): class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str embedding_model: str
@app.post("/embedding/update") @app.post("/embedding/model/update")
async def update_embedding_config( async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
) )
try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
},
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
class RerankingModelUpdateForm(BaseModel):
reranking_model: str
@app.post("/reranking/update")
async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
return {
"status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"chunk": { }
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}, @app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
} }
@ -316,24 +188,17 @@ class ChunkParamUpdateForm(BaseModel):
chunk_overlap: int chunk_overlap: int
class ConfigUpdateForm(BaseModel): @app.post("/chunk/update")
pdf_extract_images: bool async def update_chunk_params(
chunk: ChunkParamUpdateForm form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
app.state.CHUNK_SIZE = form_data.chunk_size
@app.post("/config/update") app.state.CHUNK_OVERLAP = form_data.chunk_overlap
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
app.state.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "chunk_size": app.state.CHUNK_SIZE,
"chunk": { "chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
} }
@ -345,73 +210,40 @@ async def get_rag_template(user=Depends(get_current_user)):
} }
@app.get("/query/settings") class RAGTemplateForm(BaseModel):
async def get_query_settings(user=Depends(get_admin_user)): template: str
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
}
class QuerySettingsForm(BaseModel): @app.post("/template/update")
k: Optional[int] = None async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
r: Optional[float] = None # TODO: check template requirements
template: Optional[str] = None app.state.RAG_TEMPLATE = (
hybrid: Optional[bool] = None form_data.template if form_data.template != "" else RAG_TEMPLATE
)
return {"status": True, "template": app.state.RAG_TEMPLATE}
@app.post("/query/settings/update")
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
}
class QueryDocForm(BaseModel): class QueryDocForm(BaseModel):
collection_name: str collection_name: str
query: str query: str
k: Optional[int] = None k: Optional[int] = 4
r: Optional[float] = None
hybrid: Optional[bool] = None
@app.post("/query/doc") @app.post("/query/doc")
def query_doc_handler( def query_doc(
form_data: QueryDocForm, form_data: QueryDocForm,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: # if you use docker use the model from the environment variable
return query_doc_with_hybrid_search( collection = CHROMA_CLIENT.get_collection(
collection_name=form_data.collection_name, name=form_data.collection_name,
query=form_data.query, embedding_function=app.state.sentence_transformer_ef,
embedding_function=app.state.EMBEDDING_FUNCTION, )
k=form_data.k if form_data.k else app.state.TOP_K, result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
reranking_function=app.state.sentence_transformer_rf, return result
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e: except Exception as e:
log.exception(e) print(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e), detail=ERROR_MESSAGES.DEFAULT(e),
@ -421,195 +253,104 @@ def query_doc_handler(
class QueryCollectionsForm(BaseModel): class QueryCollectionsForm(BaseModel):
collection_names: List[str] collection_names: List[str]
query: str query: str
k: Optional[int] = None k: Optional[int] = 4
r: Optional[float] = None
hybrid: Optional[bool] = None
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
combined_distances = []
combined_metadatas = []
combined_documents = []
# Combine data from each dictionary
for data in query_results:
combined_ids.extend(data["ids"][0])
combined_distances.extend(data["distances"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_documents.extend(data["documents"][0])
# Create a list of tuples (distance, id, metadata, document)
combined = list(
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
)
# Sort the list based on distances
combined.sort(key=lambda x: x[0])
# Unzip the sorted list
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_ids = list(sorted_ids)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
sorted_documents = list(sorted_documents)[:k]
# Create the output dictionary
merged_query_results = {
"ids": [sorted_ids],
"distances": [sorted_distances],
"metadatas": [sorted_metadatas],
"documents": [sorted_documents],
"embeddings": None,
"uris": None,
"data": None,
}
return merged_query_results
@app.post("/query/collection") @app.post("/query/collection")
def query_collection_handler( def query_collection(
form_data: QueryCollectionsForm, form_data: QueryCollectionsForm,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: results = []
if app.state.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search( for collection_name in form_data.collection_names:
collection_names=form_data.collection_names, try:
query=form_data.query, # if you use docker use the model from the environment variable
embedding_function=app.state.EMBEDDING_FUNCTION, collection = CHROMA_CLIENT.get_collection(
k=form_data.k if form_data.k else app.state.TOP_K, name=collection_name,
reranking_function=app.state.sentence_transformer_rf, embedding_function=app.state.sentence_transformer_ef,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
) )
except Exception as e: result = collection.query(
log.exception(e) query_texts=[form_data.query], n_results=form_data.k
raise HTTPException( )
status_code=status.HTTP_400_BAD_REQUEST, results.append(result)
detail=ERROR_MESSAGES.DEFAULT(e), except:
) pass
return merge_and_sort_query_results(results, form_data.k)
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
try:
loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False)
data = loader.load()
collection_name = form_data.collection_name
if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name, overwrite=True)
return {
"status": True,
"collection_name": collection_name,
"filename": form_data.url,
}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.post("/web") @app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)): def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader(form_data.url) loader = WebBaseLoader(form_data.url)
data = loader.load() data = loader.load()
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == "": if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63] collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name, overwrite=True) store_data_in_vector_db(data, collection_name)
return { return {
"status": True, "status": True,
"collection_name": collection_name, "collection_name": collection_name,
"filename": form_data.url, "filename": form_data.url,
} }
except Exception as e: except Exception as e:
log.exception(e) print(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e), detail=ERROR_MESSAGES.DEFAULT(e),
) )
def get_web_loader(url: str):
# Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url)
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.split_documents(data)
if len(docs) > 0:
log.info(f"store_data_in_vector_db {docs}")
return store_docs_in_vector_db(docs, collection_name, overwrite), None
else:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
return store_docs_in_vector_db(docs, collection_name, overwrite)
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
embeddings = embedding_func(embedding_texts)
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
documents=texts,
):
collection.add(*batch)
return True
except Exception as e:
log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
def get_loader(filename: str, file_content_type: str, file_path: str): def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower() file_ext = filename.split(".")[-1].lower()
known_type = True known_type = True
@ -660,15 +401,13 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
] ]
if file_ext == "pdf": if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) loader = PyPDFLoader(file_path)
elif file_ext == "csv": elif file_ext == "csv":
loader = CSVLoader(file_path) loader = CSVLoader(file_path)
elif file_ext == "rst": elif file_ext == "rst":
loader = UnstructuredRSTLoader(file_path, mode="elements") loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml": elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path) loader = UnstructuredXMLLoader(file_path)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md": elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path) loader = UnstructuredMarkdownLoader(file_path)
elif file_content_type == "application/epub+zip": elif file_content_type == "application/epub+zip":
@ -684,12 +423,10 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]: ] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path) loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or ( elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
file_content_type and file_content_type.find("text/") >= 0 loader = TextLoader(file_path)
):
loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path)
known_type = False known_type = False
return loader, known_type return loader, known_type
@ -703,13 +440,10 @@ def store_doc(
): ):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
log.info(f"file.content_type: {file.content_type}") print(file.content_type)
try: try:
unsanitized_filename = file.filename filename = file.filename
filename = os.path.basename(unsanitized_filename)
file_path = f"{UPLOAD_DIR}/{filename}" file_path = f"{UPLOAD_DIR}/{filename}"
contents = file.file.read() contents = file.file.read()
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(contents) f.write(contents)
@ -720,26 +454,24 @@ def store_doc(
collection_name = calculate_sha256(f)[:63] collection_name = calculate_sha256(f)[:63]
f.close() f.close()
loader, known_type = get_loader(filename, file.content_type, file_path) loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load() data = loader.load()
result = store_data_in_vector_db(data, collection_name)
try: if result:
result = store_data_in_vector_db(data, collection_name) return {
"status": True,
if result: "collection_name": collection_name,
return { "filename": filename,
"status": True, "known_type": known_type,
"collection_name": collection_name, }
"filename": filename, else:
"known_type": known_type,
}
except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e, detail=ERROR_MESSAGES.DEFAULT(),
) )
except Exception as e: except Exception as e:
log.exception(e) print(e)
if "No pandoc was found" in str(e): if "No pandoc was found" in str(e):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -752,41 +484,10 @@ def store_doc(
) )
class TextRAGForm(BaseModel):
name: str
content: str
collection_name: Optional[str] = None
@app.post("/text")
def store_text(
form_data: TextRAGForm,
user=Depends(get_current_user),
):
collection_name = form_data.collection_name
if collection_name == None:
collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db(
form_data.content,
metadata={"name": form_data.name, "created_by": user.id},
collection_name=collection_name,
)
if result:
return {"status": True, "collection_name": collection_name}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
@app.get("/scan") @app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)): def scan_docs_dir(user=Depends(get_admin_user)):
for path in Path(DOCS_DIR).rglob("./**/*"): try:
try: for path in Path(DOCS_DIR).rglob("./**/*"):
if path.is_file() and not path.name.startswith("."): if path.is_file() and not path.name.startswith("."):
tags = extract_folders_after_data_docs(path) tags = extract_folders_after_data_docs(path)
filename = path.name filename = path.name
@ -801,45 +502,41 @@ def scan_docs_dir(user=Depends(get_admin_user)):
) )
data = loader.load() data = loader.load()
try: result = store_data_in_vector_db(data, collection_name)
result = store_data_in_vector_db(data, collection_name)
if result: if result:
sanitized_filename = sanitize_filename(filename) sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename) doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None: if doc == None:
doc = Documents.insert_new_doc( doc = Documents.insert_new_doc(
user.id, user.id,
DocumentForm( DocumentForm(
**{ **{
"name": sanitized_filename, "name": sanitized_filename,
"title": filename, "title": filename,
"collection_name": collection_name, "collection_name": collection_name,
"filename": filename, "filename": filename,
"content": ( "content": (
json.dumps( json.dumps(
{ {
"tags": list( "tags": list(
map( map(
lambda name: {"name": name}, lambda name: {"name": name},
tags, tags,
)
) )
} )
) }
if len(tags) )
else "{}" if len(tags)
), else "{}"
} ),
), }
) ),
except Exception as e: )
log.exception(e)
pass
except Exception as e: except Exception as e:
log.exception(e) print(e)
return True return True
@ -860,11 +557,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
elif os.path.isdir(file_path): elif os.path.isdir(file_path):
shutil.rmtree(file_path) shutil.rmtree(file_path)
except Exception as e: except Exception as e:
log.error("Failed to delete %s. Reason: %s" % (file_path, e)) print("Failed to delete %s. Reason: %s" % (file_path, e))
try: try:
CHROMA_CLIENT.reset() CHROMA_CLIENT.reset()
except Exception as e: except Exception as e:
log.exception(e) print(e)
return True return True

View file

@ -1,517 +0,0 @@
import os
import logging
import requests
from typing import List
from apps.ollama.main import (
generate_ollama_embeddings,
GenerateEmbeddingsForm,
)
from huggingface_hub import snapshot_download
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import (
ContextualCompressionRetriever,
EnsembleRetriever,
)
from typing import Optional
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_doc(
collection_name: str,
query: str,
embedding_function,
k: int,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
query_embeddings = embedding_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
log.info(f"query_doc:result {result}")
return result
except Exception as e:
raise e
def query_doc_with_hybrid_search(
collection_name: str,
query: str,
embedding_function,
k: int,
reranking_function,
r: float,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
bm25_retriever = BM25Retriever.from_texts(
texts=documents.get("documents"),
metadatas=documents.get("metadatas"),
)
bm25_retriever.k = k
chroma_retriever = ChromaRetriever(
collection=collection,
embedding_function=embedding_function,
top_n=k,
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
reranking_function=reranking_function,
r_score=r,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever
)
result = compression_retriever.invoke(query)
result = {
"distances": [[d.metadata.get("score") for d in result]],
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
}
log.info(f"query_doc_with_hybrid_search:result {result}")
return result
except Exception as e:
raise e
def merge_and_sort_query_results(query_results, k, reverse=False):
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
combined_metadatas = []
for data in query_results:
combined_distances.extend(data["distances"][0])
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
# Create a list of tuples (distance, document, metadata)
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=reverse)
# We don't have anything :-(
if not combined:
sorted_distances = []
sorted_documents = []
sorted_metadatas = []
else:
# Unzip the sorted list
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_documents = list(sorted_documents)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
# Create the output dictionary
result = {
"distances": [sorted_distances],
"documents": [sorted_documents],
"metadatas": [sorted_metadatas],
}
return result
def query_collection(
collection_names: List[str],
query: str,
embedding_function,
k: int,
):
results = []
for collection_name in collection_names:
try:
result = query_doc(
collection_name=collection_name,
query=query,
k=k,
embedding_function=embedding_function,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: List[str],
query: str,
embedding_function,
k: int,
reranking_function,
r: float,
):
results = []
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
template = template.replace("[context]", context)
template = template.replace("[query]", query)
return template
def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
if embedding_engine == "ollama":
func = lambda query: generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
func = lambda query: generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
def generate_multiple(query, f):
if isinstance(query, list):
return [f(q) for q in query]
else:
return f(query)
return lambda query: generate_multiple(query, func)
def rag_messages(
docs,
messages,
template,
embedding_function,
k,
reranking_function,
r,
hybrid_search,
):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
extracted_collections = []
relevant_contexts = []
for doc in docs:
context = None
collection = doc.get("collection_name")
if collection:
collection = [collection]
else:
collection = doc.get("collection_names", [])
collection = set(collection).difference(extracted_collections)
if not collection:
log.debug(f"skipping {doc} as it has already been extracted")
continue
try:
if doc["type"] == "text":
context = doc["content"]
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
else:
context = query_collection(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
context = None
if context:
relevant_contexts.append(context)
extracted_collections.extend(collection)
context_string = ""
for context in relevant_contexts:
try:
if "documents" in context:
items = [item for item in context["documents"][0] if item is not None]
context_string += "\n\n".join(items)
except Exception as e:
log.exception(e)
context_string = context_string.strip()
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
return messages
def get_model_path(model: str, update_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_model
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"model: {model}")
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
# Inspiration from upstream sentence_transformers
if (
os.path.exists(model)
or ("\\" in model or model.count("/") > 1)
and local_files_only
):
# If fully qualified path exists, return input, else set repo_id
return model
elif "/" not in model:
# Set valid repo_id for model short-name
model = "sentence-transformers" + "/" + model
snapshot_kwargs["repo_id"] = model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"model_repo_path: {model_repo_path}")
return model_repo_path
except Exception as e:
log.exception(f"Cannot determine model snapshot path: {e}")
return model
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
try:
r = requests.post(
f"{url}/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None
from typing import Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
class ChromaRetriever(BaseRetriever):
collection: Any
embedding_function: Any
top_n: int
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
query_embeddings = self.embedding_function(query)
results = self.collection.query(
query_embeddings=[query_embeddings],
n_results=self.top_n,
)
ids = results["ids"][0]
metadatas = results["metadatas"][0]
documents = results["documents"][0]
results = []
for idx in range(len(ids)):
results.append(
Document(
metadata=metadatas[idx],
page_content=documents[idx],
)
)
return results
import operator
from typing import Optional, Sequence
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
top_n: int
reranking_function: Any
r_score: float
class Config:
extra = Extra.forbid
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
reranking = self.reranking_function is not None
if reranking:
scores = self.reranking_function.predict(
[(query, doc.page_content) for doc in documents]
)
else:
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
docs_with_scores = list(zip(documents, scores.tolist()))
if self.r_score:
docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score
]
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
final_results = []
for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata
metadata["score"] = doc_score
doc = Document(
page_content=doc.page_content,
metadata=metadata,
)
final_results.append(doc)
return final_results

View file

@ -1,23 +1,6 @@
from peewee import * from peewee import *
from peewee_migrate import Router from config import DATA_DIR
from playhouse.db_url import connect
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
# Check if the file exists DB = SqliteDatabase(f"{DATA_DIR}/ollama.db")
if os.path.exists(f"{DATA_DIR}/ollama.db"): DB.connect()
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run()
DB.connect(reuse_if_open=True)

View file

@ -1,254 +0,0 @@
"""Peewee migrations -- 001_initial_schema.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# We perform different migrations for SQLite and other databases
# This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite
# will require per-database SQL queries.
# Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base
# schema instead of trying to migrate from an older schema.
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.CharField(max_length=255)
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.CharField()
filename = pw.CharField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
content = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.TextField()
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.TextField()
filename = pw.TextField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
content = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("user")
migrator.remove_model("tag")
migrator.remove_model("prompt")
migrator.remove_model("modelfile")
migrator.remove_model("document")
migrator.remove_model("chatidtag")
migrator.remove_model("chat")
migrator.remove_model("auth")

View file

@ -1,48 +0,0 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "share_id")

View file

@ -1,48 +0,0 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "api_key")

View file

@ -1,46 +0,0 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "archived")

View file

@ -1,130 +0,0 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.DateTimeField(null=True), # Allow null for transition
updated_at=pw.DateTimeField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.DateTimeField(null=False),
updated_at=pw.DateTimeField(null=False),
)
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
rollback_sqlite(migrator, database, fake=fake)
else:
rollback_external(migrator, database, fake=fake)
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))

View file

@ -1,130 +0,0 @@
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
timestamp=pw.BigIntegerField(),
)
# Alter the tables with varchar to text where necessary
migrator.change_fields(
"auth",
password=pw.TextField(),
)
migrator.change_fields(
"chat",
title=pw.TextField(),
)
migrator.change_fields(
"document",
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
title=pw.TextField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.TextField(),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
title=pw.CharField(),
)
migrator.change_fields(
"document",
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
title=pw.CharField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.CharField(),
)

View file

@ -1,79 +0,0 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields created_at and updated_at to the 'user' table
migrator.add_fields(
"user",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL'
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("user", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"user",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
last_active_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql('UPDATE "user" SET timestamp = created_at')
# Remove the created_at and updated_at fields
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))

View file

@ -1,21 +0,0 @@
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1. Have a database file (`webui.db`) that has the old schema prior to any of your changes.
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
4. The migration file will be created in the `apps/web/internal/migrations` directory.

View file

@ -19,8 +19,6 @@ from config import (
DEFAULT_USER_ROLE, DEFAULT_USER_ROLE,
ENABLE_SIGNUP, ENABLE_SIGNUP,
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
) )
app = FastAPI() app = FastAPI()
@ -34,8 +32,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,

View file

@ -2,7 +2,6 @@ from pydantic import BaseModel
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
import uuid import uuid
import logging
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.web.models.users import UserModel, Users
@ -10,11 +9,6 @@ from utils.utils import verify_password
from apps.web.internal.db import DB from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# DB MODEL # DB MODEL
#################### ####################
@ -23,7 +17,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Auth(Model): class Auth(Model):
id = CharField(unique=True) id = CharField(unique=True)
email = CharField() email = CharField()
password = TextField() password = CharField()
active = BooleanField() active = BooleanField()
class Meta: class Meta:
@ -47,10 +41,6 @@ class Token(BaseModel):
token_type: str token_type: str
class ApiKey(BaseModel):
api_key: Optional[str] = None
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
email: str email: str
@ -86,11 +76,6 @@ class SignupForm(BaseModel):
name: str name: str
email: str email: str
password: str password: str
profile_image_url: Optional[str] = "/user.png"
class AddUserForm(SignupForm):
role: Optional[str] = "pending"
class AuthsTable: class AuthsTable:
@ -99,14 +84,9 @@ class AuthsTable:
self.db.create_tables([Auth]) self.db.create_tables([Auth])
def insert_new_auth( def insert_new_auth(
self, self, email: str, password: str, name: str, role: str = "pending"
email: str,
password: str,
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info("insert_new_auth") print("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -115,7 +95,7 @@ class AuthsTable:
) )
result = Auth.create(**auth.model_dump()) result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, profile_image_url, role) user = Users.insert_new_user(id, name, email, role)
if result and user: if result and user:
return user return user
@ -123,7 +103,7 @@ class AuthsTable:
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") print("authenticate_user", email)
try: try:
auth = Auth.get(Auth.email == email, Auth.active == True) auth = Auth.get(Auth.email == email, Auth.active == True)
if auth: if auth:
@ -137,28 +117,6 @@ class AuthsTable:
except: except:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
user = Users.get_user_by_api_key(api_key)
return user if user else None
except:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
if auth:
user = Users.get_user_by_id(auth.id)
return user
except:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try: try:
query = Auth.update(password=new_password).where(Auth.id == id) query = Auth.update(password=new_password).where(Auth.id == id)

View file

@ -17,14 +17,9 @@ from apps.web.internal.db import DB
class Chat(Model): class Chat(Model):
id = CharField(unique=True) id = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = TextField() title = CharField()
chat = TextField() # Save Chat JSON as Text chat = TextField() # Save Chat JSON as Text
timestamp = DateField()
created_at = BigIntegerField()
updated_at = BigIntegerField()
share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False)
class Meta: class Meta:
database = DB database = DB
@ -35,12 +30,7 @@ class ChatModel(BaseModel):
user_id: str user_id: str
title: str title: str
chat: str chat: str
timestamp: int # timestamp in epoch
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
archived: bool = False
#################### ####################
@ -61,17 +51,12 @@ class ChatResponse(BaseModel):
user_id: str user_id: str
title: str title: str
chat: dict chat: dict
updated_at: int # timestamp in epoch timestamp: int # timestamp in epoch
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
archived: bool
class ChatTitleIdResponse(BaseModel): class ChatTitleIdResponse(BaseModel):
id: str id: str
title: str title: str
updated_at: int
created_at: int
class ChatTable: class ChatTable:
@ -89,8 +74,7 @@ class ChatTable:
form_data.chat["title"] if "title" in form_data.chat else "New Chat" form_data.chat["title"] if "title" in form_data.chat else "New Chat"
), ),
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"created_at": int(time.time()), "timestamp": int(time.time()),
"updated_at": int(time.time()),
} }
) )
@ -102,7 +86,7 @@ class ChatTable:
query = Chat.update( query = Chat.update(
chat=json.dumps(chat), chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat", title=chat["title"] if "title" in chat else "New Chat",
updated_at=int(time.time()), timestamp=int(time.time()),
).where(Chat.id == id) ).where(Chat.id == id)
query.execute() query.execute()
@ -111,64 +95,12 @@ class ChatTable:
except: except:
return None return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
# Get the existing chat to share
chat = Chat.get(Chat.id == chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"created_at": chat.created_at,
"updated_at": int(time.time()),
}
)
shared_result = Chat.create(**shared_chat.model_dump())
# Update the original chat with the share_id
result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id)
print(chat)
query = Chat.update(
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
except:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try: try:
query = Chat.update( query = Chat.update(
share_id=share_id, chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
timestamp=int(time.time()),
).where(Chat.id == id) ).where(Chat.id == id)
query.execute() query.execute()
@ -177,75 +109,41 @@ class ChatTable:
except: except:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_lists_by_user_id(
try:
chat = self.get_chat_by_id(id)
query = Chat.update(
archived=(not chat.archived),
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat))
for chat in Chat.select() for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id) .where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc()) .order_by(Chat.timestamp.desc())
# .limit(limit) # .limit(limit)
# .offset(skip) # .offset(skip)
] ]
def get_chat_list_by_user_id( def get_chat_lists_by_chat_ids(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat))
for chat in Chat.select() for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.id.in_(chat_ids)) .where(Chat.id.in_(chat_ids))
.order_by(Chat.updated_at.desc()) .order_by(Chat.timestamp.desc())
] ]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_all_chats(self) -> List[ChatModel]:
try: return [
chat = Chat.get(Chat.id == id) ChatModel(**model_to_dict(chat))
return ChatModel(**model_to_dict(chat)) for chat in Chat.select().order_by(Chat.timestamp.desc())
except: ]
return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
try: return [
chat = Chat.get(Chat.share_id == id) ChatModel(**model_to_dict(chat))
for chat in Chat.select()
if chat: .where(Chat.user_id == user_id)
chat = Chat.get(Chat.id == id) .order_by(Chat.timestamp.desc())
return ChatModel(**model_to_dict(chat)) ]
else:
return None
except:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
@ -257,42 +155,20 @@ class ChatTable:
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.updated_at.desc()) for chat in Chat.select().limit(limit).offset(skip)
# .limit(limit).offset(skip)
] ]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True and self.delete_shared_chat_by_chat_id(id)
except:
return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True and self.delete_shared_chat_by_chat_id(id) return True
except: except:
return False return False
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
self.delete_shared_chats_by_user_id(user_id)
query = Chat.delete().where(Chat.user_id == user_id) query = Chat.delete().where(Chat.user_id == user_id)
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
@ -300,19 +176,5 @@ class ChatTable:
except: except:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
query = Chat.delete().where(Chat.user_id << shared_chat_ids)
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Chats = ChatTable(DB) Chats = ChatTable(DB)

View file

@ -3,7 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
import logging
from utils.utils import decode_token from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
@ -12,11 +11,6 @@ from apps.web.internal.db import DB
import json import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Documents DB Schema # Documents DB Schema
#################### ####################
@ -25,11 +19,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Document(Model): class Document(Model):
collection_name = CharField(unique=True) collection_name = CharField(unique=True)
name = CharField(unique=True) name = CharField(unique=True)
title = TextField() title = CharField()
filename = TextField() filename = CharField()
content = TextField(null=True) content = TextField(null=True)
user_id = CharField() user_id = CharField()
timestamp = BigIntegerField() timestamp = DateField()
class Meta: class Meta:
database = DB database = DB
@ -124,7 +118,7 @@ class DocumentsTable:
doc = Document.get(Document.name == form_data.name) doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc)) return DocumentModel(**model_to_dict(doc))
except Exception as e: except Exception as e:
log.exception(e) print(e)
return None return None
def update_doc_content_by_name( def update_doc_content_by_name(
@ -144,7 +138,7 @@ class DocumentsTable:
doc = Document.get(Document.name == name) doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc)) return DocumentModel(**model_to_dict(doc))
except Exception as e: except Exception as e:
log.exception(e) print(e)
return None return None
def delete_doc_by_name(self, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:

View file

@ -20,7 +20,7 @@ class Modelfile(Model):
tag_name = CharField(unique=True) tag_name = CharField(unique=True)
user_id = CharField() user_id = CharField()
modelfile = TextField() modelfile = TextField()
timestamp = BigIntegerField() timestamp = DateField()
class Meta: class Meta:
database = DB database = DB
@ -64,8 +64,8 @@ class ModelfilesTable:
self.db.create_tables([Modelfile]) self.db.create_tables([Modelfile])
def insert_new_modelfile( def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm self, user_id: str,
) -> Optional[ModelfileModel]: form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile: if "tagName" in form_data.modelfile:
modelfile = ModelfileModel( modelfile = ModelfileModel(
**{ **{
@ -73,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"], "tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile), "modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Modelfile.create(**modelfile.model_dump()) result = Modelfile.create(**modelfile.model_dump())
@ -88,28 +87,29 @@ class ModelfilesTable:
else: else:
return None return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try: try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name) modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile)) return ModelfileModel(**model_to_dict(modelfile))
except: except:
return None return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [ return [
ModelfileResponse( ModelfileResponse(
**{ **{
**model_to_dict(modelfile), **model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) }) for modelfile in Modelfile.select()
for modelfile in Modelfile.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_modelfile_by_tag_name( def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
) -> Optional[ModelfileModel]:
try: try:
query = Modelfile.update( query = Modelfile.update(
modelfile=json.dumps(modelfile), modelfile=json.dumps(modelfile),

View file

@ -19,9 +19,9 @@ import json
class Prompt(Model): class Prompt(Model):
command = CharField(unique=True) command = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = TextField() title = CharField()
content = TextField() content = TextField()
timestamp = BigIntegerField() timestamp = DateField()
class Meta: class Meta:
database = DB database = DB
@ -52,9 +52,8 @@ class PromptsTable:
self.db = db self.db = db
self.db.create_tables([Prompt]) self.db.create_tables([Prompt])
def insert_new_prompt( def insert_new_prompt(self, user_id: str,
self, user_id: str, form_data: PromptForm form_data: PromptForm) -> Optional[PromptModel]:
) -> Optional[PromptModel]:
prompt = PromptModel( prompt = PromptModel(
**{ **{
"user_id": user_id, "user_id": user_id,
@ -62,8 +61,7 @@ class PromptsTable:
"title": form_data.title, "title": form_data.title,
"content": form_data.content, "content": form_data.content,
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Prompt.create(**prompt.model_dump()) result = Prompt.create(**prompt.model_dump())
@ -83,14 +81,13 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [ return [
PromptModel(**model_to_dict(prompt)) PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
for prompt in Prompt.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str,
) -> Optional[PromptModel]: form_data: PromptForm) -> Optional[PromptModel]:
try: try:
query = Prompt.update( query = Prompt.update(
title=form_data.title, title=form_data.title,

View file

@ -6,15 +6,9 @@ from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
import logging
from apps.web.internal.db import DB from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Tag DB Schema # Tag DB Schema
#################### ####################
@ -35,7 +29,7 @@ class ChatIdTag(Model):
tag_name = CharField() tag_name = CharField()
chat_id = CharField() chat_id = CharField()
user_id = CharField() user_id = CharField()
timestamp = BigIntegerField() timestamp = DateField()
class Meta: class Meta:
database = DB database = DB
@ -136,9 +130,7 @@ class TagTable:
return [ return [
TagModel(**model_to_dict(tag)) TagModel(**model_to_dict(tag))
for tag in Tag.select() for tag in Tag.select().where(Tag.name.in_(tag_names))
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
] ]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
@ -153,9 +145,7 @@ class TagTable:
return [ return [
TagModel(**model_to_dict(tag)) TagModel(**model_to_dict(tag))
for tag in Tag.select() for tag in Tag.select().where(Tag.name.in_(tag_names))
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
] ]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
@ -177,27 +167,6 @@ class TagTable:
.count() .count()
) )
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id( def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
@ -208,7 +177,7 @@ class TagTable:
& (ChatIdTag.user_id == user_id) & (ChatIdTag.user_id == user_id)
) )
res = query.execute() # Remove the rows, return number of rows removed. res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}") print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0: if tag_count == 0:
@ -220,7 +189,7 @@ class TagTable:
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") print("delete_tag", e)
return False return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

View file

@ -18,13 +18,8 @@ class User(Model):
name = CharField() name = CharField()
email = CharField() email = CharField()
role = CharField() role = CharField()
profile_image_url = TextField() profile_image_url = CharField()
timestamp = DateField()
last_active_at = BigIntegerField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
api_key = CharField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
@ -35,13 +30,8 @@ class UserModel(BaseModel):
name: str name: str
email: str email: str
role: str = "pending" role: str = "pending"
profile_image_url: str profile_image_url: str = "/user.png"
timestamp: int # timestamp in epoch
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
api_key: Optional[str] = None
#################### ####################
@ -67,12 +57,7 @@ class UsersTable:
self.db.create_tables([User]) self.db.create_tables([User])
def insert_new_user( def insert_new_user(
self, self, id: str, name: str, email: str, role: str = "pending"
id: str,
name: str,
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
) -> Optional[UserModel]: ) -> Optional[UserModel]:
user = UserModel( user = UserModel(
**{ **{
@ -80,10 +65,8 @@ class UsersTable:
"name": name, "name": name,
"email": email, "email": email,
"role": role, "role": role,
"profile_image_url": profile_image_url, "profile_image_url": "/user.png",
"last_active_at": int(time.time()), "timestamp": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
} }
) )
result = User.create(**user.model_dump()) result = User.create(**user.model_dump())
@ -99,13 +82,6 @@ class UsersTable:
except: except:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = User.get(User.email == email) user = User.get(User.email == email)
@ -123,13 +99,6 @@ class UsersTable:
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return User.select().count() return User.select().count()
def get_first_user(self) -> UserModel:
try:
user = User.select().order_by(User.created_at).first()
return UserModel(**model_to_dict(user))
except:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
query = User.update(role=role).where(User.id == id) query = User.update(role=role).where(User.id == id)
@ -154,16 +123,6 @@ class UsersTable:
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
query = User.update(last_active_at=int(time.time())).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) query = User.update(**updated).where(User.id == id)
@ -190,21 +149,5 @@ class UsersTable:
except: except:
return False return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
return user.api_key
except:
return None
Users = UsersTable(DB) Users = UsersTable(DB)

View file

@ -1,25 +1,22 @@
import logging from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import Request, UploadFile, File from fastapi import APIRouter, status
from fastapi import Depends, HTTPException, status
from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import re import time
import uuid import uuid
import csv import re
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
AddUserForm,
UpdateProfileForm, UpdateProfileForm,
UpdatePasswordForm, UpdatePasswordForm,
UserResponse, UserResponse,
SigninResponse, SigninResponse,
Auths, Auths,
ApiKey,
) )
from apps.web.models.users import Users from apps.web.models.users import Users
@ -28,12 +25,9 @@ from utils.utils import (
get_current_user, get_current_user,
get_admin_user, get_admin_user,
create_token, create_token,
create_api_key,
) )
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from constants import ERROR_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
router = APIRouter() router = APIRouter()
@ -84,8 +78,6 @@ async def update_profile(
async def update_password( async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
): ):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
@ -105,22 +97,7 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse) @router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, form_data: SigninForm): async def signin(request: Request, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
if not Users.get_user_by_email(trusted_email.lower()):
await signup(
request,
SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)
else:
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
@ -168,11 +145,7 @@ async def signup(request: Request, form_data: SignupForm):
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
form_data.email.lower(), form_data.email.lower(), hashed, form_data.name, role
hashed,
form_data.name,
form_data.profile_image_url,
role,
) )
if user: if user:
@ -182,62 +155,6 @@ async def signup(request: Request, form_data: SignupForm):
) )
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL:
post_webhook(
request.app.state.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
return {
"token": token,
"token_type": "Bearer",
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
############################
# AddUser
############################
@router.post("/add", response_model=SigninResponse)
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
print(form_data)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
hashed,
form_data.name,
form_data.profile_image_url,
form_data.role,
)
if user:
token = create_token(data={"id": user.id})
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -320,40 +237,3 @@ async def update_token_expires_duration(
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
else: else:
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
############################
# API Key
############################
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None)
return success
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id)
if api_key:
return {
"api_key": api_key,
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

View file

@ -5,7 +5,6 @@ from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
import logging
from apps.web.models.users import Users from apps.web.models.users import Users
from apps.web.models.chats import ( from apps.web.models.chats import (
@ -28,81 +27,30 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################
# GetChatList
############################
@router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user.id, skip, limit)
############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
return result
############################
# GetUserChatList
############################
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user_id, skip, limit)
############################
# GetArchivedChats
############################
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
# GetChats # GetChats
############################ ############################
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
############################
# GetAllChats
############################
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)): async def get_all_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats_by_user_id(user.id) for chat in Chats.get_all_chats_by_user_id(user.id)
] ]
@ -113,14 +61,9 @@ async def get_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse]) @router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)): async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats() for chat in Chats.get_all_chats()
] ]
@ -135,12 +78,48 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e: except Exception as e:
log.exception(e) print(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
) )
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chats_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
print(chat_ids)
return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
############################ ############################
# GetChatById # GetChatById
############################ ############################
@ -188,154 +167,17 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
if user.role == "admin": if (
result = Chats.delete_chat_by_id(id) user.role == "user"
return result and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
else: ):
if not request.app.state.USER_PERMISSIONS["chat"]["deletion"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
############################
# ArchiveChat
############################
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ShareChatById
############################
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
############################ return result
# DeletedSharedChatById
############################
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if not chat.share_id:
return False
result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(id, None)
return result and update_result != None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
return chats
############################ ############################
@ -418,3 +260,14 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
) )
############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(user=Depends(get_current_user)):
result = Chats.delete_chats_by_user_id(user.id)
return result

View file

@ -10,12 +10,7 @@ import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import ( from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -48,6 +43,7 @@ async def set_global_default_models(
return request.app.state.DEFAULT_MODELS return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion]) @router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions( async def set_global_default_suggestions(
request: Request, request: Request,

View file

@ -24,9 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles( async def get_modelfiles(skip: int = 0,
skip: int = 0, limit: int = 50, user=Depends(get_current_user) limit: int = 50,
): user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
@ -36,16 +36,17 @@ async def get_modelfiles(
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)): async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) })
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -59,18 +60,17 @@ async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name( async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
form_data: ModelfileTagNameForm, user=Depends(get_current_user) user=Depends(get_current_user)):
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) })
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -84,9 +84,8 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name( async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
form_data: ModelfileUpdateForm, user=Depends(get_admin_user) user=Depends(get_admin_user)):
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
updated_modelfile = { updated_modelfile = {
@ -95,15 +94,14 @@ async def update_modelfile_by_tag_name(
} }
modelfile = Modelfiles.update_modelfile_by_tag_name( modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile form_data.tag_name, updated_modelfile)
)
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) })
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -117,8 +115,7 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name( async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
form_data: ModelfileTagNameForm, user=Depends(get_admin_user) user=Depends(get_admin_user)):
):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result return result

View file

@ -7,7 +7,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import time import time
import uuid import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
@ -15,11 +14,6 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################ ############################
@ -58,7 +52,7 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
if user.id != form_data.id and form_data.id != Users.get_first_user().id: if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException( raise HTTPException(
@ -89,7 +83,7 @@ async def update_user_by_id(
if form_data.password: if form_data.password:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}") print(hashed)
Auths.update_user_password_by_id(user_id, hashed) Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower()) Auths.update_email_by_id(user_id, form_data.email.lower())

View file

@ -1,109 +1,174 @@
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, BackgroundTasks
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase from starlette.responses import StreamingResponse
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
import requests
import os
import aiohttp
import json
from fpdf import FPDF
import markdown
from apps.web.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url from utils.misc import calculate_sha256, get_gravatar_url
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT from config import OLLAMA_API_BASE_URL, DATA_DIR, UPLOAD_DIR
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from typing import List
router = APIRouter() router = APIRouter()
class UploadBlobForm(BaseModel):
filename: str
from urllib.parse import urlparse
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise "Ollama: Could not create blob, Please try again."
@router.get("/download")
async def download(
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, file_path, file_name),
media_type="text/event-stream",
)
else:
return None
@router.post("/upload")
def upload(file: UploadFile = File(...)):
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@router.get("/gravatar") @router.get("/gravatar")
async def get_gravatar( async def get_gravatar(
email: str, email: str,
): ):
return get_gravatar_url(email) return get_gravatar_url(email)
class MarkdownForm(BaseModel):
md: str
@router.post("/markdown")
async def get_html_from_markdown(
form_data: MarkdownForm,
):
return {"html": markdown.markdown(form_data.md)}
class ChatForm(BaseModel):
title: str
messages: List[dict]
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatForm,
):
pdf = FPDF()
pdf.add_page()
STATIC_DIR = "./static"
FONTS_DIR = f"{STATIC_DIR}/fonts"
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
pdf.add_font("NotoSans", "i", f"{FONTS_DIR}/NotoSans-Italic.ttf")
pdf.add_font("NotoSansKR", "", f"{FONTS_DIR}/NotoSansKR-Regular.ttf")
pdf.add_font("NotoSansJP", "", f"{FONTS_DIR}/NotoSansJP-Regular.ttf")
pdf.set_font("NotoSans", size=12)
pdf.set_fallback_fonts(["NotoSansKR", "NotoSansJP"])
pdf.set_auto_page_break(auto=True, margin=15)
# Adjust the effective page width for multi_cell
effective_page_width = (
pdf.w - 2 * pdf.l_margin - 10
) # Subtracted an additional 10 for extra padding
# Add chat messages
for message in form_data.messages:
role = message["role"]
content = message["content"]
pdf.set_font("NotoSans", "B", size=14) # Bold for the role
pdf.multi_cell(effective_page_width, 10, f"{role.upper()}", 0, "L")
pdf.ln(1) # Extra space between messages
pdf.set_font("NotoSans", size=10) # Regular for content
pdf.multi_cell(effective_page_width, 6, content, 0, "L")
pdf.ln(1.5) # Extra space between messages
# Save the pdf with name .pdf
pdf_bytes = pdf.output()
return Response(
content=bytes(pdf_bytes),
media_type="application/pdf",
headers={"Content-Disposition": f"attachment;filename=chat.pdf"},
)
@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if not isinstance(DB, SqliteDatabase):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
DB.database,
media_type="application/octet-stream",
filename="webui.db",
)

View file

@ -1,80 +1,22 @@
import os import os
import sys
import logging
import chromadb import chromadb
from chromadb import Settings from chromadb import Settings
from secrets import token_bytes
from base64 import b64encode from base64 import b64encode
from bs4 import BeautifulSoup from constants import ERROR_MESSAGES
from pathlib import Path from pathlib import Path
import json import json
import yaml
import markdown import markdown
import requests from bs4 import BeautifulSoup
import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"LITELLM",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
####################################
# Load .env file
####################################
try: try:
from dotenv import load_dotenv, find_dotenv from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env")) load_dotenv(find_dotenv("../.env"))
except ImportError: except ImportError:
log.warning("dotenv not installed, skipping...") print("dotenv not installed, skipping...")
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
#################################### ####################################
# ENV (dev,test,prod) # ENV (dev,test,prod)
@ -82,6 +24,7 @@ WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
ENV = os.environ.get("ENV", "dev") ENV = os.environ.get("ENV", "dev")
try: try:
with open(f"../package.json", "r") as f: with open(f"../package.json", "r") as f:
PACKAGE_DATA = json.load(f) PACKAGE_DATA = json.load(f)
@ -135,6 +78,8 @@ for version in soup.find_all("h2"):
# Find the next sibling that is a h3 tag (section title) # Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling() current = version.find_next_sibling()
print(current)
while current and current.name != "h2": while current and current.name != "h2":
if current.name == "h3": if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed" section_title = current.get_text().lower() # e.g., "added", "fixed"
@ -162,48 +107,6 @@ try:
except: except:
CONFIG_DATA = {} CONFIG_DATA = {}
####################################
# Static DIR
####################################
STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve())
frontend_favicon = f"{FRONTEND_BUILD_DIR}/favicon.png"
if os.path.exists(frontend_favicon):
shutil.copyfile(frontend_favicon, f"{STATIC_DIR}/favicon.png")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
####################################
# CUSTOM_NAME
####################################
CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")
if CUSTOM_NAME:
try:
r = requests.get(f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}")
data = r.json()
if r.ok:
if "logo" in data:
WEBUI_FAVICON_URL = url = (
f"https://api.openwebui.com{data['logo']}"
if data["logo"][0] == "/"
else data["logo"]
)
r = requests.get(url, stream=True)
if r.status_code == 200:
with open(f"{STATIC_DIR}/favicon.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"]
except Exception as e:
log.exception(e)
pass
#################################### ####################################
# File Upload DIR # File Upload DIR
#################################### ####################################
@ -220,83 +123,31 @@ CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = f"{DATA_DIR}/functions"
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# Docs DIR # Docs DIR
#################################### ####################################
DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs") DOCS_DIR = f"{DATA_DIR}/docs"
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# LITELLM_CONFIG # OLLAMA_API_BASE_URL
####################################
def create_config_file(file_path):
directory = os.path.dirname(file_path)
# Check if directory exists, if not, create it
if not os.path.exists(directory):
os.makedirs(directory)
# Data to write into the YAML file
config_data = {
"general_settings": {},
"litellm_settings": {},
"model_list": [],
"router_settings": {},
}
# Write data to YAML file
with open(file_path, "w") as file:
yaml.dump(config_data, file)
LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
if not os.path.exists(LITELLM_CONFIG_PATH):
log.info("Config file doesn't exist. Creating...")
create_config_file(LITELLM_CONFIG_PATH)
log.info("Config file created successfully.")
####################################
# OLLAMA_BASE_URL
#################################### ####################################
OLLAMA_API_BASE_URL = os.environ.get( OLLAMA_API_BASE_URL = os.environ.get(
"OLLAMA_API_BASE_URL", "http://localhost:11434/api" "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
) )
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
OLLAMA_BASE_URL = (
OLLAMA_API_BASE_URL[:-4]
if OLLAMA_API_BASE_URL.endswith("/api")
else OLLAMA_API_BASE_URL
)
if ENV == "prod": if ENV == "prod":
if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG: if OLLAMA_API_BASE_URL == "/ollama/api":
if USE_OLLAMA_DOCKER.lower() == "true": OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
# if you use all-in-one docker container (Open WebUI + Ollama)
# with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
OLLAMA_BASE_URL = "http://localhost:11434"
else:
OLLAMA_BASE_URL = "http://host.docker.internal:11434"
elif K8S_FLAG:
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
#################################### ####################################
# OPENAI_API # OPENAI_API
@ -305,43 +156,15 @@ OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "": if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)
OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_KEY = ""
try:
OPENAI_API_KEY = OPENAI_API_KEYS[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
]
except:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
#################################### ####################################
# WEBUI # WEBUI
#################################### ####################################
ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", "True").lower() == "true" ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", True)
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None) DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
@ -367,36 +190,13 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Show me a code snippet", "of a website's sticky header"], "title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.", "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
}, },
{
"title": [
"Explain options trading",
"if I'm familiar with buying and selling stocks",
],
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
},
{
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
},
] ]
) )
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") DEFAULT_USER_ROLE = "pending"
USER_PERMISSIONS = {"chat": {"deletion": True}}
USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
#################################### ####################################
# WEBUI_VERSION # WEBUI_VERSION
@ -409,9 +209,6 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
#################################### ####################################
WEBUI_AUTH = True WEBUI_AUTH = True
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
#################################### ####################################
# WEBUI_SECRET_KEY # WEBUI_SECRET_KEY
@ -432,87 +229,21 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
#################################### ####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "") # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000")) RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
# Comma-separated list of header=value pairs "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
CHROMA_HTTP_HEADERS = dict(
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
)
else:
CHROMA_HTTP_HEADERS = None
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
) )
CHROMA_CLIENT = chromadb.PersistentClient(
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
) )
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if not RAG_RERANKING_MODEL == "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
if USE_CUDA.lower() == "true":
DEVICE_TYPE = "cuda"
else:
DEVICE_TYPE = "cpu"
if CHROMA_HTTP_HOST != "":
CHROMA_CLIENT = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
headers=CHROMA_HTTP_HEADERS,
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
else:
CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context> <context>
[context] [context]
</context> </context>
@ -522,74 +253,20 @@ When answer to user:
- If you don't know when you are not sure, ask for clarification. - If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context. Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question. And answer according to the language of the user's question.
Given the context information, answer the query. Given the context information, answer the query.
Query: [query]""" Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
#################################### ####################################
# Transcribe # Transcribe
#################################### ####################################
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
#################################### ####################################
# Images # Images
#################################### ####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
ENABLE_IMAGE_GENERATION = (
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
)
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
)
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
####################################
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
####################################
# LiteLLM
####################################
ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true"
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")

View file

@ -3,17 +3,6 @@ from enum import Enum
class MESSAGES(str, Enum): class MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}" DEFAULT = lambda msg="": f"{msg if msg else ''}"
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
MODEL_DELETED = (
lambda model="": f"The model '{model}' has been deleted successfully."
)
class WEBHOOK_MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
USER_SIGNUP = lambda username="": (
f"New user signed up: {username}" if username else "New user signed up"
)
class ERROR_MESSAGES(str, Enum): class ERROR_MESSAGES(str, Enum):
@ -24,7 +13,6 @@ class ERROR_MESSAGES(str, Enum):
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot." DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again."
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew." EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
USERNAME_TAKEN = ( USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username." "Uh-oh! This username is already registered. Please choose another username."
@ -41,7 +29,6 @@ class ERROR_MESSAGES(str, Enum):
INVALID_PASSWORD = ( INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again." "The password provided is incorrect. Please check for typos and try again."
) )
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
UNAUTHORIZED = "401 Unauthorized" UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = ( ACTION_PROHIBITED = (
@ -54,24 +41,9 @@ class ERROR_MESSAGES(str, Enum):
NOT_FOUND = "We could not find what you're looking for :/" NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes." MALICIOUS = "Unusual activities detected, please try again in a few minutes."
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
INCORRECT_FORMAT = ( INCORRECT_FORMAT = (
lambda err="": f"Invalid format. Please use the correct format{err}" lambda err="": f"Invalid format. Please use the correct format{err if err else ''}"
)
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
INVALID_URL = (
"Oops! The URL you provided is invalid. Please double-check and try again."
) )

View file

@ -1,36 +1,34 @@
{ {
"version": 0, "ui": {
"ui": { "prompt_suggestions": [
"default_locale": "en-US", {
"prompt_suggestions": [ "title": [
{ "Help me study",
"title": ["Help me study", "vocabulary for a college entrance exam"], "vocabulary for a college entrance exam"
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option." ],
}, "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
{ },
"title": ["Give me ideas", "for what to do with my kids' art"], {
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter." "title": [
}, "Give me ideas",
{ "for what to do with my kids' art"
"title": ["Tell me a fun fact", "about the Roman Empire"], ],
"content": "Tell me a random fun fact about the Roman Empire" "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
}, },
{ {
"title": ["Show me a code snippet", "of a website's sticky header"], "title": [
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript." "Tell me a fun fact",
}, "about the Roman Empire"
{ ],
"title": ["Explain options trading", "if I'm familiar with buying and selling stocks"], "content": "Tell me a random fun fact about the Roman Empire"
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks." },
}, {
{ "title": [
"title": ["Overcome procrastination", "give me tips"], "Show me a code snippet",
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?" "of a website's sticky header"
}, ],
{ "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
"title": ["Grammar check", "rewrite it for better readability "], }
"content": "Check the following sentence for grammar and clarity: \"[sentence]\". Rewrite it for better readability while maintaining its original meaning." ]
} }
] }
}
}

View file

@ -1,4 +0,0 @@
general_settings: {}
litellm_settings: {}
model_list: []
router_settings: {}

0
backend/dev.sh Executable file → Normal file
View file

View file

@ -2,64 +2,26 @@ from bs4 import BeautifulSoup
import json import json
import markdown import markdown
import time import time
import os
import sys
import logging
import aiohttp
import requests
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app from apps.openai.main import app as openai_app
from apps.litellm.main import (
app as litellm_app,
start_litellm_background,
shutdown_litellm_background,
)
from apps.audio.main import app as audio_app from apps.audio.main import app as audio_app
from apps.functions.main import app as functions_app
from apps.images.main import app as images_app from apps.images.main import app as images_app
from apps.rag.main import app as rag_app from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
import asyncio from config import ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
from pydantic import BaseModel
from typing import List
from utils.utils import get_admin_user
from apps.rag.utils import rag_messages
from config import (
CONFIG_DATA,
WEBUI_NAME,
ENV,
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
CACHE_DIR,
STATIC_DIR,
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
)
from constants import ERROR_MESSAGES
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class SPAStaticFiles(StaticFiles): class SPAStaticFiles(StaticFiles):
@ -73,88 +35,10 @@ class SPAStaticFiles(StaticFiles):
raise ex raise ex
print(
f"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
v{VERSION} - building the best open-source AI user interface.
https://github.com/open-webui/open-webui
"""
)
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"] origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if "docs" in data:
data = {**data}
data["messages"] = rag_messages(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
)
del data["docs"]
log.debug(f"data['messages']: {data['messages']}")
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
]
response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(RAGMiddleware)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
@ -174,163 +58,35 @@ async def check_url(request: Request, call_next):
return response return response
@app.on_event("startup")
async def on_startup():
if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
app.mount("/api/v1", webui_app) app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app)
app.mount("/ollama", ollama_app) app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app) app.mount("/openai/api", openai_app)
app.mount("/images/api/v1", images_app) app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app) app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app) app.mount("/rag/api/v1", rag_app)
app.mount("/functions/api/v1", functions_app)
@app.get("/api/config") @app.get("/api/config")
async def get_app_config(): async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
default_locale = "en-US"
if "ui" in CONFIG_DATA:
default_locale = CONFIG_DATA["ui"].get("default_locale", "en-US")
# The Rest of the Function Now Uses the Variables Defined Above
return { return {
"status": True, "status": True,
"name": WEBUI_NAME,
"version": VERSION, "version": VERSION,
"default_locale": default_locale,
"images": images_app.state.ENABLED, "images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS, "default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
}
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {
"enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST,
}
class ModelFilterConfigForm(BaseModel):
enabled: bool
models: List[str]
@app.post("/api/config/model/filter")
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.ENABLE_MODEL_FILTER = form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
return {
"enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST,
}
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {
"url": app.state.WEBHOOK_URL,
}
class UrlForm(BaseModel):
url: str
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
return {
"url": app.state.WEBHOOK_URL,
}
@app.get("/api/version")
async def get_app_config():
return {
"version": VERSION,
} }
@app.get("/api/changelog") @app.get("/api/changelog")
async def get_app_changelog(): async def get_app_changelog():
return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} return CHANGELOG
@app.get("/api/version/updates") app.mount(
async def get_app_latest_release_version(): "/",
try: SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
async with aiohttp.ClientSession() as session: name="spa-static-files",
async with session.get( )
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:
response.raise_for_status()
data = await response.json()
latest_version = data["tag_name"]
return {"current": VERSION, "latest": latest_version[1:]}
except aiohttp.ClientError as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
)
@app.get("/manifest.json")
async def get_manifest_json():
return {
"name": WEBUI_NAME,
"short_name": WEBUI_NAME,
"start_url": "/",
"display": "standalone",
"background_color": "#343541",
"theme_color": "#343541",
"orientation": "portrait-primary",
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
if os.path.exists(FRONTEND_BUILD_DIR):
app.mount(
"/",
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
name="spa-static-files",
)
else:
log.warning(
f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
)
@app.on_event("shutdown")
async def shutdown_event():
if ENABLE_LITELLM:
await shutdown_litellm_background()

View file

@ -1,62 +1,38 @@
fastapi==0.109.2 fastapi
uvicorn[standard]==0.22.0 uvicorn[standard]
pydantic==2.7.1 pydantic
python-multipart==0.0.9 python-multipart
Flask==3.0.3 flask
Flask-Cors==4.0.0 flask_cors
python-socketio==5.11.2 python-socketio
python-jose==3.3.0 python-jose
passlib[bcrypt]==1.7.4 passlib[bcrypt]
uuid==1.30 uuid
requests==2.31.0 requests
aiohttp==3.9.5 aiohttp
peewee==3.17.3 peewee
peewee-migrate==1.12.2 bcrypt
psycopg2-binary==2.9.9
PyMySQL==1.1.0
bcrypt==4.1.2
litellm==1.35.28 langchain
litellm[proxy]==1.35.28 langchain-community
chromadb
sentence_transformers
pypdf
docx2txt
unstructured
markdown
pypandoc
pandas
openpyxl
pyxlsb
xlrd
boto3==1.34.95 faster-whisper
argon2-cffi==23.1.0 PyJWT
APScheduler==3.10.4 pyjwt[crypto]
google-generativeai==0.5.2
langchain==0.1.16 black
langchain-community==0.0.34
langchain-chroma==0.1.0
fake-useragent==1.5.1
chromadb==0.4.24
sentence-transformers==2.7.0
pypdf==4.2.0
docx2txt==0.8
unstructured==0.11.8
Markdown==3.6
pypandoc==1.13
pandas==2.2.2
openpyxl==3.1.2
pyxlsb==1.0.10
xlrd==2.0.1
validators==0.28.1
opencv-python-headless==4.9.0.80
rapidocr-onnxruntime==1.2.3
fpdf2==2.7.8
rank-bm25==0.2.2
faster-whisper==1.0.1
PyJWT==2.8.0
PyJWT[crypto]==2.8.0
black==24.4.2
langfuse==2.27.3
youtube-transcript-api

View file

@ -6,28 +6,17 @@ cd "$SCRIPT_DIR" || exit
KEY_FILE=.webui_secret_key KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}" PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
echo "No WEBUI_SECRET_KEY provided" echo No WEBUI_SECRET_KEY provided
if ! [ -e "$KEY_FILE" ]; then if ! [ -e "$KEY_FILE" ]; then
echo "Generating WEBUI_SECRET_KEY" echo Generating WEBUI_SECRET_KEY
# Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one. # Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one.
echo $(head -c 12 /dev/random | base64) > "$KEY_FILE" echo $(head -c 12 /dev/random | base64) > $KEY_FILE
fi fi
echo "Loading WEBUI_SECRET_KEY from $KEY_FILE" echo Loading WEBUI_SECRET_KEY from $KEY_FILE
WEBUI_SECRET_KEY=$(cat "$KEY_FILE") WEBUI_SECRET_KEY=`cat $KEY_FILE`
fi fi
if [ "$USE_OLLAMA_DOCKER" = "true" ]; then WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'
echo "USE_OLLAMA is set to true, starting ollama serve."
ollama serve &
fi
if [ "$USE_CUDA_DOCKER" = "true" ]; then
echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*'

View file

@ -7,7 +7,7 @@ SET "SCRIPT_DIR=%~dp0"
cd /d "%SCRIPT_DIR%" || exit /b cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key" SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080 SET "PORT=%PORT:8080%"
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%" SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%" SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6 KiB

View file

@ -1 +0,0 @@
Name,Email,Password,Role
1 Name Email Password Role

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6 KiB

View file

@ -1,8 +1,6 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -10,7 +8,6 @@ from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
import jwt import jwt
import uuid
import logging import logging
import config import config
@ -61,26 +58,9 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def create_api_key():
key = str(uuid.uuid4()).replace("-", "")
return f"sk-{key}"
def get_http_authorization_cred(auth_header: str):
try:
scheme, credentials = auth_header.split(" ")
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
except:
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
def get_current_user( def get_current_user(
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
): ):
# auth by api key
if auth_token.credentials.startswith("sk-"):
return get_current_user_by_api_key(auth_token.credentials)
# auth by jwt token
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
@ -89,8 +69,6 @@ def get_current_user(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else:
Users.update_user_last_active_by_id(user.id)
return user return user
else: else:
raise HTTPException( raise HTTPException(
@ -99,20 +77,6 @@ def get_current_user(
) )
def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(user.id)
return user
def get_verified_user(user=Depends(get_current_user)): def get_verified_user(user=Depends(get_current_user)):
if user.role not in {"user", "admin"}: if user.role not in {"user", "admin"}:
raise HTTPException( raise HTTPException(

View file

@ -1,54 +0,0 @@
import json
import requests
import logging
from config import SRC_LOG_LEVELS, VERSION, WEBUI_FAVICON_URL, WEBUI_NAME
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
def post_webhook(url: str, message: str, event_data: dict) -> bool:
try:
payload = {}
# Slack and Google Chat Webhooks
if "https://hooks.slack.com" in url or "https://chat.googleapis.com" in url:
payload["text"] = message
# Discord Webhooks
elif "https://discord.com/api/webhooks" in url:
payload["content"] = message
# Microsoft Teams Webhooks
elif "webhook.office.com" in url:
action = event_data.get("action", "undefined")
facts = [
{"name": name, "value": value}
for name, value in json.loads(event_data.get("user", {})).items()
]
payload = {
"@type": "MessageCard",
"@context": "http://schema.org/extensions",
"themeColor": "0076D7",
"summary": message,
"sections": [
{
"activityTitle": message,
"activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
"activityImage": WEBUI_FAVICON_URL,
"facts": facts,
"markdown": True,
}
],
}
# Default Payload
else:
payload = {**event_data}
log.debug(f"payload: {payload}")
r = requests.post(url, json=payload)
r.raise_for_status()
log.debug(f"r.text: {r.text}")
return True
except Exception as e:
log.exception(e)
return False

View file

@ -1,13 +0,0 @@
#!/bin/bash
echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]"
read ans
if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then
command docker-compose 2>/dev/null
if [ "$?" == "0" ]; then
docker-compose down -v
else
docker compose down -v
fi
else
echo "Operation cancelled."
fi

View file

@ -1,8 +0,0 @@
import { defineConfig } from 'cypress';
export default defineConfig({
e2e: {
baseUrl: 'http://localhost:8080'
},
video: true
});

View file

@ -1,46 +0,0 @@
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
// These tests run through the chat flow.
describe('Settings', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/');
});
context('Ollama', () => {
it('user can select a model', () => {
// Click on the model selector
cy.get('button[aria-label="Select a model"]').click();
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
});
it('user can perform text chat', () => {
// Click on the model selector
cy.get('button[aria-label="Select a model"]').click();
// Select the first model
cy.get('button[aria-label="model-item"]').first().click();
// Type a message
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
cy.get('button[type="submit"]').click();
// User's message should be visible
cy.get('.chat-user').should('exist');
// Wait for the response
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
.should('exist');
});
});
});

View file

@ -1,52 +0,0 @@
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
import { adminUser } from '../support/e2e';
// These tests assume the following defaults:
// 1. No users exist in the database or that the test admin user is an admin
// 2. Language is set to English
// 3. The default role for new users is 'pending'
describe('Registration and Login', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
cy.visit('/');
});
it('should register a new user as pending', () => {
const userName = `Test User - ${Date.now()}`;
const userEmail = `cypress-${Date.now()}@example.com`;
// Toggle from sign in to sign up
cy.contains('Sign up').click();
// Fill out the form
cy.get('input[autocomplete="name"]').type(userName);
cy.get('input[autocomplete="email"]').type(userEmail);
cy.get('input[type="password"]').type('password');
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.contains(userName);
// Expect the user to be pending
cy.contains('Check Again');
});
it('can login with the admin user', () => {
// Fill out the form
cy.get('input[autocomplete="email"]').type(adminUser.email);
cy.get('input[type="password"]').type(adminUser.password);
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.contains(adminUser.name);
// Dismiss the changelog dialog if it is visible
cy.getAllLocalStorage().then((ls) => {
if (!ls['version']) {
cy.get('button').contains("Okay, Let's Go!").click();
}
});
});
});

View file

@ -1,88 +0,0 @@
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
import { adminUser } from '../support/e2e';
// These tests run through the various settings pages, ensuring that the user can interact with them as expected
describe('Settings', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/');
// Open the sidebar if it is not already open
cy.get('[aria-label="Open sidebar"]').then(() => {
cy.get('button[id="sidebar-toggle-button"]').click();
});
// Click on the profile link
cy.get('button').contains(adminUser.name).click();
// Click on the settings link
cy.get('button').contains('Settings').click();
});
context('General', () => {
it('user can open the General modal and hit save', () => {
cy.get('button').contains('General').click();
cy.get('button').contains('Save').click();
});
});
context('Connections', () => {
it('user can open the Connections modal and hit save', () => {
cy.get('button').contains('Connections').click();
cy.get('button').contains('Save').click();
});
});
context('Models', () => {
it('user can open the Models modal', () => {
cy.get('button').contains('Models').click();
});
});
context('Interface', () => {
it('user can open the Interface modal and hit save', () => {
cy.get('button').contains('Interface').click();
cy.get('button').contains('Save').click();
});
});
context('Audio', () => {
it('user can open the Audio modal and hit save', () => {
cy.get('button').contains('Audio').click();
cy.get('button').contains('Save').click();
});
});
context('Images', () => {
it('user can open the Images modal and hit save', () => {
cy.get('button').contains('Images').click();
// Currently fails because the backend requires a valid URL
// cy.get('button').contains('Save').click();
});
});
context('Chats', () => {
it('user can open the Chats modal', () => {
cy.get('button').contains('Chats').click();
});
});
context('Account', () => {
it('user can open the Account modal and hit save', () => {
cy.get('button').contains('Account').click();
cy.get('button').contains('Save').click();
});
});
context('About', () => {
it('user can open the About modal', () => {
cy.get('button').contains('About').click();
});
});
});

View file

@ -1,73 +0,0 @@
/// <reference types="cypress" />
export const adminUser = {
name: 'Admin User',
email: 'admin@example.com',
password: 'password'
};
const login = (email: string, password: string) => {
return cy.session(
email,
() => {
// Visit auth page
cy.visit('/auth');
// Fill out the form
cy.get('input[autocomplete="email"]').type(email);
cy.get('input[type="password"]').type(password);
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.get('#chat-search').should('exist');
// Get the current version to skip the changelog dialog
if (localStorage.getItem('version') === null) {
cy.get('button').contains("Okay, Let's Go!").click();
}
},
{
validate: () => {
cy.request({
method: 'GET',
url: '/api/v1/auths/',
headers: {
Authorization: 'Bearer ' + localStorage.getItem('token')
}
});
}
}
);
};
const register = (name: string, email: string, password: string) => {
return cy
.request({
method: 'POST',
url: '/api/v1/auths/signup',
body: {
name: name,
email: email,
password: password
},
failOnStatusCode: false
})
.then((response) => {
expect(response.status).to.be.oneOf([200, 400]);
});
};
const registerAdmin = () => {
return register(adminUser.name, adminUser.email, adminUser.password);
};
const loginAdmin = () => {
return login(adminUser.email, adminUser.password);
};
Cypress.Commands.add('login', (email, password) => login(email, password));
Cypress.Commands.add('register', (name, email, password) => register(name, email, password));
Cypress.Commands.add('registerAdmin', () => registerAdmin());
Cypress.Commands.add('loginAdmin', () => loginAdmin());
before(() => {
cy.registerAdmin();
});

View file

@ -1,11 +0,0 @@
// load the global Cypress types
/// <reference types="cypress" />
declare namespace Cypress {
interface Chainable {
login(email: string, password: string): Chainable<Element>;
register(name: string, email: string, password: string): Chainable<Element>;
registerAdmin(): Chainable<Element>;
loginAdmin(): Chainable<Element>;
}
}

View file

@ -1,7 +0,0 @@
{
"extends": "../tsconfig.json",
"compilerOptions": {
"inlineSourceMap": true,
"sourceMap": false
}
}

BIN
demo.gif

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5 MiB

After

Width:  |  Height:  |  Size: 6.1 MiB

View file

@ -1,8 +0,0 @@
services:
ollama:
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
image: ollama/ollama:${OLLAMA_DOCKER_TAG-rocm}
environment:
- 'HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION-11.0.0}'

View file

@ -8,24 +8,24 @@ services:
pull_policy: always pull_policy: always
tty: true tty: true
restart: unless-stopped restart: unless-stopped
image: ollama/ollama:${OLLAMA_DOCKER_TAG-latest} image: ollama/ollama:latest
open-webui: open-webui:
build: build:
context: . context: .
args: args:
OLLAMA_BASE_URL: '/ollama' OLLAMA_API_BASE_URL: '/ollama/api'
dockerfile: Dockerfile dockerfile: Dockerfile
image: ghcr.io/open-webui/open-webui:${WEBUI_DOCKER_TAG-main} image: ghcr.io/open-webui/open-webui:main
container_name: open-webui container_name: open-webui
volumes: volumes:
- open-webui:/app/backend/data - open-webui:/app/backend/data
depends_on: depends_on:
- ollama - ollama
ports: ports:
- ${OPEN_WEBUI_PORT-3000}:8080 - ${OLLAMA_WEBUI_PORT-3000}:8080
environment: environment:
- 'OLLAMA_BASE_URL=http://ollama:11434' - 'OLLAMA_API_BASE_URL=http://ollama:11434/api'
- 'WEBUI_SECRET_KEY=' - 'WEBUI_SECRET_KEY='
extra_hosts: extra_hosts:
- host.docker.internal:host-gateway - host.docker.internal:host-gateway

View file

@ -50,18 +50,6 @@ We welcome pull requests. Before submitting one, please:
Help us make Open WebUI more accessible by improving documentation, writing tutorials, or creating guides on setting up and optimizing the web UI. Help us make Open WebUI more accessible by improving documentation, writing tutorials, or creating guides on setting up and optimizing the web UI.
### 🌐 Translations and Internationalization
Help us make Open WebUI available to a wider audience. In this section, we'll guide you through the process of adding new translations to the project.
We use JSON files to store translations. You can find the existing translation files in the `src/lib/i18n/locales` directory. Each directory corresponds to a specific language, for example, `en-US` for English (US), `fr-FR` for French (France) and so on. You can refer to [ISO 639 Language Codes][http://www.lingoes.net/en/translator/langcode.htm] to find the appropriate code for a specific language.
To add a new language:
- Create a new directory in the `src/lib/i18n/locales` path with the appropriate language code as its name. For instance, if you're adding translations for Spanish (Spain), create a new directory named `es-ES`.
- Copy the American English translation file(s) (from `en-US` directory in `src/lib/i18n/locale`) to this new directory and update the string values in JSON format according to your language. Make sure to preserve the structure of the JSON object.
- Add the language code and its respective title to languages file at `src/lib/i18n/locales/languages.json`.
### 🤔 Questions & Feedback ### 🤔 Questions & Feedback
Got questions or feedback? Join our [Discord community](https://discord.gg/5rJgQTnV4s) or open an issue. We're here to help! Got questions or feedback? Join our [Discord community](https://discord.gg/5rJgQTnV4s) or open an issue. We're here to help!

View file

@ -1,38 +0,0 @@
// i18next-parser.config.ts
import { getLanguages } from './src/lib/i18n/index.ts';
const getLangCodes = async () => {
const languages = await getLanguages();
return languages.map((l) => l.code);
};
export default {
contextSeparator: '_',
createOldCatalogs: false,
defaultNamespace: 'translation',
defaultValue: '',
indentation: 2,
keepRemoved: false,
keySeparator: false,
lexers: {
svelte: ['JavascriptLexer'],
js: ['JavascriptLexer'],
ts: ['JavascriptLexer'],
default: ['JavascriptLexer']
},
lineEnding: 'auto',
locales: await getLangCodes(),
namespaceSeparator: false,
output: 'src/lib/i18n/locales/$LOCALE/$NAMESPACE.json',
pluralSeparator: '_',
input: 'src/**/*.{js,svelte}',
sort: true,
verbose: true,
failOnWarnings: false,
failOnUpdate: false,
customValueTemplate: null,
resetDefaultValueLocale: null,
i18nextOptions: null,
yamlOptions: null
};

View file

@ -1 +0,0 @@
values-minikube.yaml

View file

@ -1,21 +1,5 @@
apiVersion: v2 apiVersion: v2
name: open-webui name: open-webui
version: 1.0.0
appVersion: "latest"
home: https://www.openwebui.com/
icon: https://raw.githubusercontent.com/open-webui/open-webui/main/static/favicon.png
description: "Open WebUI: A User-Friendly Web Interface for Chat Interactions 👋" description: "Open WebUI: A User-Friendly Web Interface for Chat Interactions 👋"
keywords: version: 1.0.0
- llm icon: https://raw.githubusercontent.com/open-webui/open-webui/main/static/favicon.png
- chat
- web-ui
sources:
- https://github.com/open-webui/open-webui/tree/main/kubernetes/helm
- https://hub.docker.com/r/ollama/ollama
- https://github.com/open-webui/open-webui/pkgs/container/open-webui
annotations:
licenses: MIT

View file

@ -1,51 +0,0 @@
{{- define "open-webui.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end -}}
{{- define "ollama.name" -}}
ollama
{{- end -}}
{{- define "ollama.url" -}}
{{- if .Values.ollama.externalHost }}
{{- printf .Values.ollama.externalHost }}
{{- else }}
{{- printf "http://%s.%s.svc.cluster.local:%d" (include "ollama.name" .) (.Release.Namespace) (.Values.ollama.service.port | int) }}
{{- end }}
{{- end }}
{{- define "chart.name" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- define "base.labels" -}}
helm.sh/chart: {{ include "chart.name" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{- define "base.selectorLabels" -}}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end -}}
{{- define "open-webui.selectorLabels" -}}
{{ include "base.selectorLabels" . }}
app.kubernetes.io/component: {{ .Chart.Name }}
{{- end }}
{{- define "open-webui.labels" -}}
{{ include "base.labels" . }}
{{ include "open-webui.selectorLabels" . }}
{{- end }}
{{- define "ollama.selectorLabels" -}}
{{ include "base.selectorLabels" . }}
app.kubernetes.io/component: {{ include "ollama.name" . }}
{{- end }}
{{- define "ollama.labels" -}}
{{ include "base.labels" . }}
{{ include "ollama.selectorLabels" . }}
{{- end }}

View file

@ -0,0 +1,4 @@
apiVersion: v1
kind: Namespace
metadata:
name: {{ .Values.namespace }}

View file

@ -1,23 +1,13 @@
{{- if not .Values.ollama.externalHost }}
apiVersion: v1 apiVersion: v1
kind: Service kind: Service
metadata: metadata:
name: {{ include "ollama.name" . }} name: ollama-service
labels: namespace: {{ .Values.namespace }}
{{- include "ollama.labels" . | nindent 4 }}
{{- with .Values.ollama.service.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec: spec:
type: {{ .Values.ollama.service.type }}
selector: selector:
{{- include "ollama.selectorLabels" . | nindent 4 }} app: ollama
{{- with .Values.ollama.service }}
type: {{ .type }}
ports: ports:
- protocol: TCP - protocol: TCP
name: http port: {{ .Values.ollama.servicePort }}
port: {{ .port }} targetPort: {{ .Values.ollama.servicePort }}
targetPort: http
{{- end }}
{{- end }}

View file

@ -1,44 +1,24 @@
{{- if not .Values.ollama.externalHost }}
apiVersion: apps/v1 apiVersion: apps/v1
kind: StatefulSet kind: StatefulSet
metadata: metadata:
name: {{ include "ollama.name" . }} name: ollama
labels: namespace: {{ .Values.namespace }}
{{- include "ollama.labels" . | nindent 4 }}
{{- with .Values.ollama.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec: spec:
serviceName: {{ include "ollama.name" . }} serviceName: "ollama"
replicas: {{ .Values.ollama.replicaCount }} replicas: {{ .Values.ollama.replicaCount }}
selector: selector:
matchLabels: matchLabels:
{{- include "ollama.selectorLabels" . | nindent 6 }} app: ollama
template: template:
metadata: metadata:
labels: labels:
{{- include "ollama.labels" . | nindent 8 }} app: ollama
{{- with .Values.ollama.podAnnotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec: spec:
enableServiceLinks: false
automountServiceAccountToken: false
{{- with .Values.ollama.runtimeClassName }}
runtimeClassName: {{ . }}
{{- end }}
containers: containers:
- name: {{ include "ollama.name" . }} - name: ollama
{{- with .Values.ollama.image }} image: {{ .Values.ollama.image }}
image: {{ .repository }}:{{ .tag }}
imagePullPolicy: {{ .pullPolicy }}
{{- end }}
tty: true
ports: ports:
- name: http - containerPort: {{ .Values.ollama.servicePort }}
containerPort: {{ .Values.ollama.service.containerPort }}
env: env:
{{- if .Values.ollama.gpu.enabled }} {{- if .Values.ollama.gpu.enabled }}
- name: PATH - name: PATH
@ -47,52 +27,29 @@ spec:
value: /usr/local/nvidia/lib:/usr/local/nvidia/lib64 value: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: NVIDIA_DRIVER_CAPABILITIES - name: NVIDIA_DRIVER_CAPABILITIES
value: compute,utility value: compute,utility
{{- end }} {{- end}}
{{- with .Values.ollama.resources }} {{- if .Values.ollama.resources }}
resources: {{- toYaml . | nindent 10 }} resources: {{- toYaml .Values.ollama.resources | nindent 10 }}
{{- end }} {{- end }}
volumeMounts: volumeMounts:
- name: data - name: ollama-volume
mountPath: /root/.ollama mountPath: /root/.ollama
tty: true
{{- with .Values.ollama.nodeSelector }} {{- with .Values.ollama.nodeSelector }}
nodeSelector: nodeSelector:
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- with .Values.ollama.tolerations }}
tolerations: tolerations:
{{- toYaml . | nindent 8 }} {{- if .Values.ollama.gpu.enabled }}
{{- end }} - key: nvidia.com/gpu
volumes: operator: Exists
{{- if and .Values.ollama.persistence.enabled .Values.ollama.persistence.existingClaim }} effect: NoSchedule
- name: data {{- end }}
persistentVolumeClaim:
claimName: {{ .Values.ollama.persistence.existingClaim }}
{{- else if not .Values.ollama.persistence.enabled }}
- name: data
emptyDir: {}
{{- else if and .Values.ollama.persistence.enabled (not .Values.ollama.persistence.existingClaim) }}
[]
volumeClaimTemplates: volumeClaimTemplates:
- metadata: - metadata:
name: data name: ollama-volume
labels:
{{- include "ollama.selectorLabels" . | nindent 8 }}
{{- with .Values.ollama.persistence.annotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec: spec:
accessModes: accessModes: [ "ReadWriteOnce" ]
{{- range .Values.ollama.persistence.accessModes }}
- {{ . | quote }}
{{- end }}
resources: resources:
requests: requests:
storage: {{ .Values.ollama.persistence.size | quote }} storage: {{ .Values.ollama.volumeSize }}
storageClassName: {{ .Values.ollama.persistence.storageClass }}
{{- with .Values.ollama.persistence.selector }}
selector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- end }}
{{- end }}

View file

@ -1,62 +1,38 @@
apiVersion: apps/v1 apiVersion: apps/v1
kind: Deployment kind: Deployment
metadata: metadata:
name: {{ include "open-webui.name" . }} name: open-webui-deployment
labels: namespace: {{ .Values.namespace }}
{{- include "open-webui.labels" . | nindent 4 }}
{{- with .Values.webui.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec: spec:
replicas: {{ .Values.webui.replicaCount }} replicas: 1
selector: selector:
matchLabels: matchLabels:
{{- include "open-webui.selectorLabels" . | nindent 6 }} app: open-webui
template: template:
metadata: metadata:
labels: labels:
{{- include "open-webui.labels" . | nindent 8 }} app: open-webui
{{- with .Values.webui.podAnnotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec: spec:
enableServiceLinks: false
automountServiceAccountToken: false
containers: containers:
- name: {{ .Chart.Name }} - name: open-webui
{{- with .Values.webui.image }} image: {{ .Values.webui.image }}
image: {{ .repository }}:{{ .tag | default $.Chart.AppVersion }}
imagePullPolicy: {{ .pullPolicy }}
{{- end }}
ports: ports:
- name: http - containerPort: 8080
containerPort: {{ .Values.webui.service.containerPort }} {{- if .Values.webui.resources }}
{{- with .Values.webui.resources }} resources: {{- toYaml .Values.webui.resources | nindent 10 }}
resources: {{- toYaml . | nindent 10 }}
{{- end }} {{- end }}
volumeMounts: volumeMounts:
- name: data - name: webui-volume
mountPath: /app/backend/data mountPath: /app/backend/data
env: env:
- name: OLLAMA_BASE_URL - name: OLLAMA_API_BASE_URL
value: {{ include "ollama.url" . | quote }} value: "http://ollama-service.{{ .Values.namespace }}.svc.cluster.local:{{ .Values.ollama.servicePort }}/api"
tty: true tty: true
{{- with .Values.webui.nodeSelector }} {{- with .Values.webui.nodeSelector }}
nodeSelector: nodeSelector:
{{- toYaml . | nindent 8 }} {{- toYaml . | nindent 8 }}
{{- end }} {{- end }}
volumes: volumes:
{{- if and .Values.webui.persistence.enabled .Values.webui.persistence.existingClaim }} - name: webui-volume
- name: data
persistentVolumeClaim: persistentVolumeClaim:
claimName: {{ .Values.webui.persistence.existingClaim }} claimName: open-webui-pvc
{{- else if not .Values.webui.persistence.enabled }}
- name: data
emptyDir: {}
{{- else if and .Values.webui.persistence.enabled (not .Values.webui.persistence.existingClaim) }}
- name: data
persistentVolumeClaim:
claimName: {{ include "open-webui.name" . }}
{{- end }}

View file

@ -2,23 +2,13 @@
apiVersion: networking.k8s.io/v1 apiVersion: networking.k8s.io/v1
kind: Ingress kind: Ingress
metadata: metadata:
name: {{ include "open-webui.name" . }} name: open-webui-ingress
labels: namespace: {{ .Values.namespace }}
{{- include "open-webui.labels" . | nindent 4 }} {{- if .Values.webui.ingress.annotations }}
{{- with .Values.webui.ingress.annotations }}
annotations: annotations:
{{- toYaml . | nindent 4 }} {{ toYaml .Values.webui.ingress.annotations | trimSuffix "\n" | indent 4 }}
{{- end }} {{- end }}
spec: spec:
{{- with .Values.webui.ingress.class }}
ingressClassName: {{ . }}
{{- end }}
{{- if .Values.webui.ingress.tls }}
tls:
- hosts:
- {{ .Values.webui.ingress.host | quote }}
secretName: {{ default (printf "%s-tls" .Release.Name) .Values.webui.ingress.existingSecret }}
{{- end }}
rules: rules:
- host: {{ .Values.webui.ingress.host }} - host: {{ .Values.webui.ingress.host }}
http: http:
@ -27,7 +17,7 @@ spec:
pathType: Prefix pathType: Prefix
backend: backend:
service: service:
name: {{ include "open-webui.name" . }} name: open-webui-service
port: port:
name: http number: {{ .Values.webui.servicePort }}
{{- end }} {{- end }}

View file

@ -1,27 +1,12 @@
{{- if and .Values.webui.persistence.enabled (not .Values.webui.persistence.existingClaim) }}
apiVersion: v1 apiVersion: v1
kind: PersistentVolumeClaim kind: PersistentVolumeClaim
metadata: metadata:
name: {{ include "open-webui.name" . }}
labels: labels:
{{- include "open-webui.selectorLabels" . | nindent 4 }} app: open-webui
{{- with .Values.webui.persistence.annotations }} name: open-webui-pvc
annotations: namespace: {{ .Values.namespace }}
{{- toYaml . | nindent 8 }}
{{- end }}
spec: spec:
accessModes: accessModes: [ "ReadWriteOnce" ]
{{- range .Values.webui.persistence.accessModes }}
- {{ . | quote }}
{{- end }}
resources: resources:
requests: requests:
storage: {{ .Values.webui.persistence.size }} storage: {{ .Values.webui.volumeSize }}
{{- if .Values.webui.persistence.storageClass }}
storageClassName: {{ .Values.webui.persistence.storageClass }}
{{- end }}
{{- with .Values.webui.persistence.selector }}
selector:
{{- toYaml . | nindent 4 }}
{{- end }}
{{- end }}

View file

@ -1,29 +1,15 @@
apiVersion: v1 apiVersion: v1
kind: Service kind: Service
metadata: metadata:
name: {{ include "open-webui.name" . }} name: open-webui-service
labels: namespace: {{ .Values.namespace }}
{{- include "open-webui.labels" . | nindent 4 }}
{{- with .Values.webui.service.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
{{- with .Values.webui.service.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec: spec:
type: {{ .Values.webui.service.type }} # Default: NodePort # Use LoadBalancer if you're on a cloud that supports it
selector: selector:
{{- include "open-webui.selectorLabels" . | nindent 4 }} app: open-webui
type: {{ .Values.webui.service.type | default "ClusterIP" }}
ports: ports:
- protocol: TCP - protocol: TCP
name: http port: {{ .Values.webui.servicePort }}
port: {{ .Values.webui.service.port }} targetPort: {{ .Values.webui.servicePort }}
targetPort: http # If using NodePort, you can optionally specify the nodePort:
{{- if .Values.webui.service.nodePort }} # nodePort: 30000
nodePort: {{ .Values.webui.service.nodePort | int }}
{{- end }}
{{- if .Values.webui.service.loadBalancerClass }}
loadBalancerClass: {{ .Values.webui.service.loadBalancerClass | quote }}
{{- end }}

View file

@ -1,27 +0,0 @@
ollama:
resources:
requests:
cpu: "2000m"
memory: "2Gi"
limits:
cpu: "4000m"
memory: "4Gi"
nvidia.com/gpu: "0"
service:
type: ClusterIP
gpu:
enabled: false
webui:
resources:
requests:
cpu: "500m"
memory: "500Mi"
limits:
cpu: "1000m"
memory: "1Gi"
ingress:
enabled: true
host: open-webui.minikube.local
service:
type: NodePort

View file

@ -1,75 +1,44 @@
nameOverride: "" namespace: open-webui
ollama: ollama:
externalHost: ""
annotations: {}
podAnnotations: {}
replicaCount: 1 replicaCount: 1
image: image: ollama/ollama:latest
repository: ollama/ollama servicePort: 11434
tag: latest resources:
pullPolicy: Always requests:
resources: {} cpu: "2000m"
persistence: memory: "2Gi"
enabled: true limits:
size: 30Gi cpu: "4000m"
existingClaim: "" memory: "4Gi"
accessModes: nvidia.com/gpu: "0"
- ReadWriteOnce volumeSize: 30Gi
storageClass: ""
selector: {}
annotations: {}
nodeSelector: {}
# -- If using a special runtime container such as nvidia, set it here.
runtimeClassName: ""
tolerations:
- key: nvidia.com/gpu
operator: Exists
effect: NoSchedule
service:
type: ClusterIP
annotations: {}
port: 80
containerPort: 11434
gpu:
# -- Enable additional ENV values to help Ollama discover GPU usage
enabled: false
webui:
annotations: {}
podAnnotations: {}
replicaCount: 1
image:
repository: ghcr.io/open-webui/open-webui
tag: ""
pullPolicy: Always
resources: {}
ingress:
enabled: false
class: ""
# -- Use appropriate annotations for your Ingress controller, e.g., for NGINX:
# nginx.ingress.kubernetes.io/rewrite-target: /
annotations: {}
host: ""
tls: false
existingSecret: ""
persistence:
enabled: true
size: 2Gi
existingClaim: ""
# -- If using multiple replicas, you must update accessModes to ReadWriteMany
accessModes:
- ReadWriteOnce
storageClass: ""
selector: {}
annotations: {}
nodeSelector: {} nodeSelector: {}
tolerations: [] tolerations: []
service: service:
type: ClusterIP type: ClusterIP
annotations: {} gpu:
port: 80 enabled: false
containerPort: 8080
nodePort: "" webui:
labels: {} replicaCount: 1
loadBalancerClass: "" image: ghcr.io/open-webui/open-webui:main
servicePort: 8080
resources:
requests:
cpu: "500m"
memory: "500Mi"
limits:
cpu: "1000m"
memory: "1Gi"
ingress:
enabled: true
annotations:
# Use appropriate annotations for your Ingress controller, e.g., for NGINX:
# nginx.ingress.kubernetes.io/rewrite-target: /
host: open-webui.minikube.local
volumeSize: 2Gi
nodeSelector: {}
tolerations: []
service:
type: NodePort

View file

@ -26,8 +26,8 @@ spec:
cpu: "1000m" cpu: "1000m"
memory: "1Gi" memory: "1Gi"
env: env:
- name: OLLAMA_BASE_URL - name: OLLAMA_API_BASE_URL
value: "http://ollama-service.open-webui.svc.cluster.local:11434" value: "http://ollama-service.open-webui.svc.cluster.local:11434/api"
tty: true tty: true
volumeMounts: volumeMounts:
- name: webui-volume - name: webui-volume
@ -35,4 +35,4 @@ spec:
volumes: volumes:
- name: webui-volume - name: webui-volume
persistentVolumeClaim: persistentVolumeClaim:
claimName: open-webui-pvc claimName: ollama-webui-pvc

View file

@ -2,9 +2,9 @@ apiVersion: v1
kind: PersistentVolumeClaim kind: PersistentVolumeClaim
metadata: metadata:
labels: labels:
app: open-webui app: ollama-webui
name: open-webui-pvc name: ollama-webui-pvc
namespace: open-webui namespace: ollama-namespace
spec: spec:
accessModes: ["ReadWriteOnce"] accessModes: ["ReadWriteOnce"]
resources: resources:

Some files were not shown because too many files have changed in this diff Show more