forked from open-webui/open-webui
Merge remote-tracking branch 'upstream/main'
Some checks failed
Release / release (push) Failing after 39s
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Failing after 21s
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Failing after 9s
Create and publish Docker images with specific build args / merge-main-images (push) Has been skipped
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64) (push) Failing after 44s
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64) (push) Failing after 25s
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been skipped
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64) (push) Failing after 25s
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64) (push) Failing after 26s
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been skipped
Python CI / Format Backend (3.11) (push) Successful in 40s
Frontend Build / Format & Build Frontend (push) Successful in 1m50s
Integration Test / Run Cypress Integration Tests (push) Failing after 12m2s
Integration Test / Run Migration Tests (push) Failing after 37s
Some checks failed
Release / release (push) Failing after 39s
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Failing after 21s
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Failing after 9s
Create and publish Docker images with specific build args / merge-main-images (push) Has been skipped
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64) (push) Failing after 44s
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64) (push) Failing after 25s
Create and publish Docker images with specific build args / merge-cuda-images (push) Has been skipped
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64) (push) Failing after 25s
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64) (push) Failing after 26s
Create and publish Docker images with specific build args / merge-ollama-images (push) Has been skipped
Python CI / Format Backend (3.11) (push) Successful in 40s
Frontend Build / Format & Build Frontend (push) Successful in 1m50s
Integration Test / Run Cypress Integration Tests (push) Failing after 12m2s
Integration Test / Run Migration Tests (push) Failing after 37s
This commit is contained in:
commit
71dd9bcbe5
176 changed files with 15094 additions and 3296 deletions
|
@ -10,6 +10,7 @@ OPENAI_API_KEY=''
|
|||
# DO NOT TRACK
|
||||
SCARF_NO_ANALYTICS=true
|
||||
DO_NOT_TRACK=true
|
||||
ANONYMIZED_TELEMETRY=false
|
||||
|
||||
# Use locally bundled version of the LiteLLM cost map json
|
||||
# to avoid repetitive startup connections
|
||||
|
|
|
@ -4,6 +4,7 @@ module.exports = {
|
|||
'eslint:recommended',
|
||||
'plugin:@typescript-eslint/recommended',
|
||||
'plugin:svelte/recommended',
|
||||
'plugin:cypress/recommended',
|
||||
'prettier'
|
||||
],
|
||||
parser: '@typescript-eslint/parser',
|
||||
|
|
3
.github/ISSUE_TEMPLATE/bug_report.md
vendored
3
.github/ISSUE_TEMPLATE/bug_report.md
vendored
|
@ -24,6 +24,9 @@ assignees: ''
|
|||
|
||||
## Environment
|
||||
|
||||
- **Open WebUI Version:** [e.g., 0.1.120]
|
||||
- **Ollama (if applicable):** [e.g., 0.1.30, 0.1.32-rc1]
|
||||
|
||||
- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
|
||||
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
|
||||
|
||||
|
|
11
.github/dependabot.yml
vendored
Normal file
11
.github/dependabot.yml
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: pip
|
||||
directory: "/backend"
|
||||
schedule:
|
||||
interval: daily
|
||||
time: "13:00"
|
||||
groups:
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
30
.github/pull_request_template.md
vendored
30
.github/pull_request_template.md
vendored
|
@ -2,14 +2,16 @@
|
|||
|
||||
- [ ] **Description:** Briefly describe the changes in this pull request.
|
||||
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
|
||||
- [ ] **Documentation:** Have you updated relevant documentation?
|
||||
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
|
||||
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
|
||||
- [ ] **Testing:** Have you written and run sufficient tests for the changes?
|
||||
- [ ] **Code Review:** Have you self-reviewed your code and addressed any coding standard issues?
|
||||
|
||||
---
|
||||
|
||||
## Description
|
||||
|
||||
[Insert a brief description of the changes made in this pull request]
|
||||
[Insert a brief description of the changes made in this pull request, including any relevant motivation and impact.]
|
||||
|
||||
---
|
||||
|
||||
|
@ -17,16 +19,32 @@
|
|||
|
||||
### Added
|
||||
|
||||
- [List any new features or additions]
|
||||
- [List any new features, functionalities, or additions]
|
||||
|
||||
### Fixed
|
||||
|
||||
- [List any fixes or corrections]
|
||||
- [List any fixes, corrections, or bug fixes]
|
||||
|
||||
### Changed
|
||||
|
||||
- [List any changes or updates]
|
||||
- [List any changes, updates, refactorings, or optimizations]
|
||||
|
||||
### Removed
|
||||
|
||||
- [List any removed features or files]
|
||||
- [List any removed features, files, or deprecated functionalities]
|
||||
|
||||
### Security
|
||||
|
||||
- [List any new or updated security-related changes, including vulnerability fixes]
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- [List any breaking changes affecting compatibility or functionality]
|
||||
|
||||
---
|
||||
|
||||
### Additional Information
|
||||
|
||||
- [Insert any additional context, notes, or explanations for the changes]
|
||||
|
||||
- [Reference any related issues, commits, or other relevant information]
|
||||
|
|
11
.github/workflows/build-release.yml
vendored
11
.github/workflows/build-release.yml
vendored
|
@ -57,3 +57,14 @@ jobs:
|
|||
path: .
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Trigger Docker build workflow
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
github.rest.actions.createWorkflowDispatch({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
workflow_id: 'docker-build.yaml',
|
||||
ref: 'v${{ steps.get_version.outputs.version }}',
|
||||
})
|
||||
|
|
351
.github/workflows/docker-build.yaml
vendored
351
.github/workflows/docker-build.yaml
vendored
|
@ -1,8 +1,7 @@
|
|||
#
|
||||
name: Create and publish a Docker image
|
||||
name: Create and publish Docker images with specific build args
|
||||
|
||||
# Configures this workflow to run every time a change is pushed to the branch called `release`.
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
@ -10,15 +9,14 @@ on:
|
|||
tags:
|
||||
- v*
|
||||
|
||||
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
|
||||
env:
|
||||
REGISTRY: git.depeuter.dev
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
RUNNER_TOOL_CACHE: /toolcache
|
||||
FULL_IMAGE_NAME: ${{ env.REGISTRY }}/${{ github.repository }}
|
||||
|
||||
# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
build-main-image:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
|
@ -26,17 +24,28 @@ jobs:
|
|||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
#
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
# Required for multi architecture build
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
# Required for multi architecture build
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
|
@ -44,12 +53,11 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker images
|
||||
- name: Extract metadata for Docker images (default latest tag)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
# This configuration dynamically generates tags based on the branch, tag, commit, and custom suffix for lite version.
|
||||
images: ${{ env.FULL_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
|
@ -59,11 +67,322 @@ jobs:
|
|||
flavor: |
|
||||
latest=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
- name: Build Docker image (latest)
|
||||
uses: docker/build-push-action@v5
|
||||
id: build
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-main-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
build-cuda-image:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker images (default latest tag)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.FULL_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=sha,prefix=git-
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
|
||||
flavor: |
|
||||
latest=${{ github.ref == 'refs/heads/main' }}
|
||||
suffix=-cuda,onlatest=true
|
||||
|
||||
- name: Build Docker image (cuda)
|
||||
uses: docker/build-push-action@v5
|
||||
id: build
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.platform }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: USE_CUDA=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-cuda-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
build-ollama-image:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker images (ollama tag)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.FULL_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=sha,prefix=git-
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
|
||||
flavor: |
|
||||
latest=${{ github.ref == 'refs/heads/main' }}
|
||||
suffix=-ollama,onlatest=true
|
||||
|
||||
- name: Build Docker image (ollama)
|
||||
uses: docker/build-push-action@v5
|
||||
id: build
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.platform }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: USE_OLLAMA=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-ollama-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge-main-images:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [ build-main-image ]
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: digests-main-*
|
||||
path: /tmp/digests
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker images (default latest tag)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.FULL_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=sha,prefix=git-
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
flavor: |
|
||||
latest=${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
|
||||
merge-cuda-images:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [ build-cuda-image ]
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: digests-cuda-*
|
||||
path: /tmp/digests
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker images (default latest tag)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.FULL_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=sha,prefix=git-
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
|
||||
flavor: |
|
||||
latest=${{ github.ref == 'refs/heads/main' }}
|
||||
suffix=-cuda,onlatest=true
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
merge-ollama-images:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [ build-ollama-image ]
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: digests-ollama-*
|
||||
path: /tmp/digests
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker images (default ollama tag)
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.FULL_IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=sha,prefix=git-
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
|
||||
flavor: |
|
||||
latest=${{ github.ref == 'refs/heads/main' }}
|
||||
suffix=-ollama,onlatest=true
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
|
||||
|
|
3
.github/workflows/format-build-frontend.yaml
vendored
3
.github/workflows/format-build-frontend.yaml
vendored
|
@ -29,6 +29,9 @@ jobs:
|
|||
- name: Format Frontend
|
||||
run: npm run format
|
||||
|
||||
- name: Run i18next
|
||||
run: npm run i18n:parse
|
||||
|
||||
- name: Check for Changes After Format
|
||||
run: git diff --exit-code
|
||||
|
||||
|
|
186
.github/workflows/integration-test.yml
vendored
Normal file
186
.github/workflows/integration-test.yml
vendored
Normal file
|
@ -0,0 +1,186 @@
|
|||
name: Integration Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
cypress-run:
|
||||
name: Run Cypress Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Build and run Compose Stack
|
||||
run: |
|
||||
docker compose up --detach --build
|
||||
|
||||
- name: Preload Ollama model
|
||||
run: |
|
||||
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
|
||||
|
||||
- name: Cypress run
|
||||
uses: cypress-io/github-action@v6
|
||||
with:
|
||||
browser: chrome
|
||||
wait-on: 'http://localhost:3000'
|
||||
config: baseUrl=http://localhost:3000
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
name: Upload Cypress videos
|
||||
with:
|
||||
name: cypress-videos
|
||||
path: cypress/videos
|
||||
if-no-files-found: ignore
|
||||
|
||||
- name: Extract Compose logs
|
||||
if: always()
|
||||
run: |
|
||||
docker compose logs > compose-logs.txt
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
name: Upload Compose logs
|
||||
with:
|
||||
name: compose-logs
|
||||
path: compose-logs.txt
|
||||
if-no-files-found: ignore
|
||||
|
||||
migration_test:
|
||||
name: Run Migration Tests
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres
|
||||
env:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
# mysql:
|
||||
# image: mysql
|
||||
# env:
|
||||
# MYSQL_ROOT_PASSWORD: mysql
|
||||
# MYSQL_DATABASE: mysql
|
||||
# options: >-
|
||||
# --health-cmd "mysqladmin ping -h localhost"
|
||||
# --health-interval 10s
|
||||
# --health-timeout 5s
|
||||
# --health-retries 5
|
||||
# ports:
|
||||
# - 3306:3306
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up uv
|
||||
uses: yezz123/setup-uv@v4
|
||||
with:
|
||||
uv-venv: venv
|
||||
|
||||
- name: Activate virtualenv
|
||||
run: |
|
||||
. venv/bin/activate
|
||||
echo PATH=$PATH >> $GITHUB_ENV
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -r backend/requirements.txt
|
||||
|
||||
- name: Test backend with SQLite
|
||||
id: sqlite
|
||||
env:
|
||||
WEBUI_SECRET_KEY: secret-key
|
||||
GLOBAL_LOG_LEVEL: debug
|
||||
run: |
|
||||
cd backend
|
||||
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
|
||||
UVICORN_PID=$!
|
||||
# Wait up to 20 seconds for the server to start
|
||||
for i in {1..20}; do
|
||||
curl -s http://localhost:8080/api/config > /dev/null && break
|
||||
sleep 1
|
||||
if [ $i -eq 20 ]; then
|
||||
echo "Server failed to start"
|
||||
kill -9 $UVICORN_PID
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
# Check that the server is still running after 5 seconds
|
||||
sleep 5
|
||||
if ! kill -0 $UVICORN_PID; then
|
||||
echo "Server has stopped"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
- name: Test backend with Postgres
|
||||
if: success() || steps.sqlite.conclusion == 'failure'
|
||||
env:
|
||||
WEBUI_SECRET_KEY: secret-key
|
||||
GLOBAL_LOG_LEVEL: debug
|
||||
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
|
||||
run: |
|
||||
cd backend
|
||||
uvicorn main:app --port "8081" --forwarded-allow-ips '*' &
|
||||
UVICORN_PID=$!
|
||||
# Wait up to 20 seconds for the server to start
|
||||
for i in {1..20}; do
|
||||
curl -s http://localhost:8081/api/config > /dev/null && break
|
||||
sleep 1
|
||||
if [ $i -eq 20 ]; then
|
||||
echo "Server failed to start"
|
||||
kill -9 $UVICORN_PID
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
# Check that the server is still running after 5 seconds
|
||||
sleep 5
|
||||
if ! kill -0 $UVICORN_PID; then
|
||||
echo "Server has stopped"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# - name: Test backend with MySQL
|
||||
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
|
||||
# env:
|
||||
# WEBUI_SECRET_KEY: secret-key
|
||||
# GLOBAL_LOG_LEVEL: debug
|
||||
# DATABASE_URL: mysql://root:mysql@localhost:3306/mysql
|
||||
# run: |
|
||||
# cd backend
|
||||
# uvicorn main:app --port "8083" --forwarded-allow-ips '*' &
|
||||
# UVICORN_PID=$!
|
||||
# # Wait up to 20 seconds for the server to start
|
||||
# for i in {1..20}; do
|
||||
# curl -s http://localhost:8083/api/config > /dev/null && break
|
||||
# sleep 1
|
||||
# if [ $i -eq 20 ]; then
|
||||
# echo "Server failed to start"
|
||||
# kill -9 $UVICORN_PID
|
||||
# exit 1
|
||||
# fi
|
||||
# done
|
||||
# # Check that the server is still running after 5 seconds
|
||||
# sleep 5
|
||||
# if ! kill -0 $UVICORN_PID; then
|
||||
# echo "Server has stopped"
|
||||
# exit 1
|
||||
# fi
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -298,3 +298,7 @@ dist
|
|||
.yarn/build-state.yml
|
||||
.yarn/install-state.gz
|
||||
.pnp.*
|
||||
|
||||
# cypress artifacts
|
||||
cypress/videos
|
||||
cypress/screenshots
|
||||
|
|
119
CHANGELOG.md
119
CHANGELOG.md
|
@ -5,6 +5,125 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.1.123] - 2024-05-02
|
||||
|
||||
### Added
|
||||
|
||||
- **🎨 New Landing Page Design**: Refreshed design for a more modern look and optimized use of screen space.
|
||||
- **📹 Youtube RAG Pipeline**: Introduces dedicated RAG pipeline for Youtube videos, enabling interaction with video transcriptions directly.
|
||||
- **🔧 Enhanced Admin Panel**: Streamlined user management with options to add users directly or in bulk via CSV import.
|
||||
- **👥 '@' Model Integration**: Easily switch to specific models during conversations; old collaborative chat feature phased out.
|
||||
- **🌐 Language Enhancements**: Swedish translation added, plus improvements to German, Spanish, and the addition of Doge translation.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🗑️ Delete Chat Shortcut**: Addressed issue where shortcut wasn't functioning.
|
||||
- **🖼️ Modal Closing Bug**: Resolved unexpected closure of modal when dragging from within.
|
||||
- **✏️ Edit Button Styling**: Fixed styling inconsistency with edit buttons.
|
||||
- **🌐 Image Generation Compatibility Issue**: Rectified image generation compatibility issue with third-party APIs.
|
||||
- **📱 iOS PWA Icon Fix**: Corrected iOS PWA home screen icon shape.
|
||||
- **🔍 Scroll Gesture Bug**: Adjusted gesture sensitivity to prevent accidental activation when scrolling through code on mobile; now requires scrolling from the leftmost side to open the sidebar.
|
||||
|
||||
### Changed
|
||||
|
||||
- **🔄 Unlimited Context Length**: Advanced settings now allow unlimited max context length (previously limited to 16000).
|
||||
- **👑 Super Admin Assignment**: The first signup is automatically assigned a super admin role, unchangeable by other admins.
|
||||
- **🛡️ Admin User Restrictions**: User action buttons from the admin panel are now disabled for users with admin roles.
|
||||
- **🔝 Default Model Selector**: Set as default model option now exclusively available on the landing page.
|
||||
|
||||
## [0.1.122] - 2024-04-27
|
||||
|
||||
### Added
|
||||
|
||||
- **🌟 Enhanced RAG Pipeline**: Now with hybrid searching via 'BM25', reranking powered by 'CrossEncoder', and configurable relevance score thresholds.
|
||||
- **🛢️ External Database Support**: Seamlessly connect to custom SQLite or Postgres databases using the 'DATABASE_URL' environment variable.
|
||||
- **🌐 Remote ChromaDB Support**: Introducing the capability to connect to remote ChromaDB servers.
|
||||
- **👨💼 Improved Admin Panel**: Admins can now conveniently check users' chat lists and last active status directly from the admin panel.
|
||||
- **🎨 Splash Screen**: Introducing a loading splash screen for a smoother user experience.
|
||||
- **🌍 Language Support Expansion**: Added support for Bangla (bn-BD), along with enhancements to Chinese, Spanish, and Ukrainian translations.
|
||||
- **💻 Improved LaTeX Rendering Performance**: Enjoy faster rendering times for LaTeX equations.
|
||||
- **🔧 More Environment Variables**: Explore additional environment variables in our documentation (https://docs.openwebui.com), including the 'ENABLE_LITELLM' option to manage memory usage.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔧 Ollama Compatibility**: Resolved errors occurring when Ollama server version isn't an integer, such as SHA builds or RCs.
|
||||
- **🐛 Various OpenAI API Issues**: Addressed several issues related to the OpenAI API.
|
||||
- **🛑 Stop Sequence Issue**: Fixed the problem where the stop sequence with a backslash '\' was not functioning.
|
||||
- **🔤 Font Fallback**: Corrected font fallback issue.
|
||||
|
||||
### Changed
|
||||
|
||||
- **⌨️ Prompt Input Behavior on Mobile**: Enter key prompt submission disabled on mobile devices for improved user experience.
|
||||
|
||||
## [0.1.121] - 2024-04-24
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔧 Translation Issues**: Addressed various translation discrepancies.
|
||||
- **🔒 LiteLLM Security Fix**: Updated LiteLLM version to resolve a security vulnerability.
|
||||
- **🖥️ HTML Tag Display**: Rectified the issue where the '< br >' tag wasn't displaying correctly.
|
||||
- **🔗 WebSocket Connection**: Resolved the failure of WebSocket connection under HTTPS security for ComfyUI server.
|
||||
- **📜 FileReader Optimization**: Implemented FileReader initialization per image in multi-file drag & drop to ensure reusability.
|
||||
- **🏷️ Tag Display**: Corrected tag display inconsistencies.
|
||||
- **📦 Archived Chat Styling**: Fixed styling issues in archived chat.
|
||||
- **🔖 Safari Copy Button Bug**: Addressed the bug where the copy button failed to copy links in Safari.
|
||||
|
||||
## [0.1.120] - 2024-04-20
|
||||
|
||||
### Added
|
||||
|
||||
- **📦 Archive Chat Feature**: Easily archive chats with a new sidebar button, and access archived chats via the profile button > archived chats.
|
||||
- **🔊 Configurable Text-to-Speech Endpoint**: Customize your Text-to-Speech experience with configurable OpenAI endpoints.
|
||||
- **🛠️ Improved Error Handling**: Enhanced error message handling for connection failures.
|
||||
- **⌨️ Enhanced Shortcut**: When editing messages, use ctrl/cmd+enter to save and submit, and esc to close.
|
||||
- **🌐 Language Support**: Added support for Georgian and enhanced translations for Portuguese and Vietnamese.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔧 Model Selector**: Resolved issue where default model selection was not saving.
|
||||
- **🔗 Share Link Copy Button**: Fixed bug where the copy button wasn't copying links in Safari.
|
||||
- **🎨 Light Theme Styling**: Addressed styling issue with the light theme.
|
||||
|
||||
## [0.1.119] - 2024-04-16
|
||||
|
||||
### Added
|
||||
|
||||
- **🌟 Enhanced RAG Embedding Support**: Ollama, and OpenAI models can now be used for RAG embedding model.
|
||||
- **🔄 Seamless Integration**: Copy 'ollama run <model name>' directly from Ollama page to easily select and pull models.
|
||||
- **🏷️ Tagging Feature**: Add tags to chats directly via the sidebar chat menu.
|
||||
- **📱 Mobile Accessibility**: Swipe left and right on mobile to effortlessly open and close the sidebar.
|
||||
- **🔍 Improved Navigation**: Admin panel now supports pagination for user list.
|
||||
- **🌍 Additional Language Support**: Added Polish language support.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🌍 Language Enhancements**: Vietnamese and Spanish translations have been improved.
|
||||
- **🔧 Helm Fixes**: Resolved issues with Helm trailing slash and manifest.json.
|
||||
|
||||
### Changed
|
||||
|
||||
- **🐳 Docker Optimization**: Updated docker image build process to utilize 'uv' for significantly faster builds compared to 'pip3'.
|
||||
|
||||
## [0.1.118] - 2024-04-10
|
||||
|
||||
### Added
|
||||
|
||||
- **🦙 Ollama and CUDA Images**: Added support for ':ollama' and ':cuda' tagged images.
|
||||
- **👍 Enhanced Response Rating**: Now you can annotate your ratings for better feedback.
|
||||
- **👤 User Initials Profile Photo**: User initials are now the default profile photo.
|
||||
- **🔍 Update RAG Embedding Model**: Customize RAG embedding model directly in document settings.
|
||||
- **🌍 Additional Language Support**: Added Turkish language support.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **🔒 Share Chat Permission**: Resolved issue with chat sharing permissions.
|
||||
- **🛠 Modal Close**: Modals can now be closed using the Esc key.
|
||||
|
||||
### Changed
|
||||
|
||||
- **🎨 Admin Panel Styling**: Refreshed styling for the admin panel.
|
||||
- **🐳 Docker Image Build**: Updated docker image build process for improved efficiency.
|
||||
|
||||
## [0.1.117] - 2024-04-03
|
||||
|
||||
### Added
|
||||
|
|
140
Dockerfile
140
Dockerfile
|
@ -1,82 +1,128 @@
|
|||
# syntax=docker/dockerfile:1
|
||||
# Initialize device type args
|
||||
# use build args in the docker build commmand with --build-arg="BUILDARG=true"
|
||||
ARG USE_CUDA=false
|
||||
ARG USE_OLLAMA=false
|
||||
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
|
||||
ARG USE_CUDA_VER=cu121
|
||||
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
||||
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
||||
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
||||
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
||||
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
ARG USE_RERANKING_MODEL=""
|
||||
|
||||
FROM node:alpine as build
|
||||
######## WebUI frontend ########
|
||||
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# wget embedding model weight from alpine (does not exist from slim-buster)
|
||||
RUN wget "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" -O - | \
|
||||
tar -xzf - -C /app
|
||||
|
||||
COPY package.json package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY . .
|
||||
RUN npm run build
|
||||
|
||||
|
||||
######## WebUI backend ########
|
||||
FROM python:3.11-slim-bookworm as base
|
||||
|
||||
ENV ENV=prod
|
||||
ENV PORT ""
|
||||
# Use args
|
||||
ARG USE_CUDA
|
||||
ARG USE_OLLAMA
|
||||
ARG USE_CUDA_VER
|
||||
ARG USE_EMBEDDING_MODEL
|
||||
ARG USE_RERANKING_MODEL
|
||||
|
||||
ENV OLLAMA_BASE_URL "/ollama"
|
||||
## Basis ##
|
||||
ENV ENV=prod \
|
||||
PORT=8080 \
|
||||
# pass build args to the build
|
||||
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
|
||||
USE_CUDA_DOCKER=${USE_CUDA} \
|
||||
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
|
||||
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
|
||||
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
|
||||
|
||||
ENV OPENAI_API_BASE_URL ""
|
||||
ENV OPENAI_API_KEY ""
|
||||
## Basis URL Config ##
|
||||
ENV OLLAMA_BASE_URL="/ollama" \
|
||||
OPENAI_API_BASE_URL=""
|
||||
|
||||
ENV WEBUI_SECRET_KEY ""
|
||||
ENV WEBUI_AUTH_TRUSTED_EMAIL_HEADER ""
|
||||
|
||||
ENV SCARF_NO_ANALYTICS true
|
||||
ENV DO_NOT_TRACK true
|
||||
## API Key and Security Config ##
|
||||
ENV OPENAI_API_KEY="" \
|
||||
WEBUI_SECRET_KEY="" \
|
||||
SCARF_NO_ANALYTICS=true \
|
||||
DO_NOT_TRACK=true \
|
||||
ANONYMIZED_TELEMETRY=false
|
||||
|
||||
# Use locally bundled version of the LiteLLM cost map json
|
||||
# to avoid repetitive startup connections
|
||||
ENV LITELLM_LOCAL_MODEL_COST_MAP="True"
|
||||
|
||||
######## Preloaded models ########
|
||||
# whisper TTS Settings
|
||||
ENV WHISPER_MODEL="base"
|
||||
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
||||
|
||||
# RAG Embedding Model Settings
|
||||
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
||||
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
||||
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
||||
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
||||
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
|
||||
# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
|
||||
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
|
||||
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
|
||||
#### Other models #########################################################
|
||||
## whisper TTS model settings ##
|
||||
ENV WHISPER_MODEL="base" \
|
||||
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
||||
|
||||
######## Preloaded models ########
|
||||
## RAG Embedding model settings ##
|
||||
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
|
||||
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
|
||||
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
|
||||
|
||||
## Hugging Face download cache ##
|
||||
ENV HF_HOME="/app/backend/data/cache/embedding/models"
|
||||
#### Other models ##########################################################
|
||||
|
||||
WORKDIR /app/backend
|
||||
|
||||
ENV HOME /root
|
||||
RUN mkdir -p $HOME/.cache/chroma
|
||||
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
|
||||
|
||||
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
||||
apt-get update && \
|
||||
# Install pandoc and netcat
|
||||
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
|
||||
# for RAG OCR
|
||||
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
||||
# install helper tools
|
||||
apt-get install -y --no-install-recommends curl && \
|
||||
# install ollama
|
||||
curl -fsSL https://ollama.com/install.sh | sh && \
|
||||
# cleanup
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
else \
|
||||
apt-get update && \
|
||||
# Install pandoc and netcat
|
||||
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
|
||||
# for RAG OCR
|
||||
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
||||
# cleanup
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
# install python dependencies
|
||||
COPY ./backend/requirements.txt ./requirements.txt
|
||||
|
||||
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
|
||||
RUN pip3 install uv && \
|
||||
if [ "$USE_CUDA" = "true" ]; then \
|
||||
# If you use CUDA the whisper and embedding model will be downloaded on first use
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
|
||||
uv pip install --system -r requirements.txt --no-cache-dir && \
|
||||
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
||||
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
||||
else \
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
|
||||
uv pip install --system -r requirements.txt --no-cache-dir && \
|
||||
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
||||
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
||||
fi
|
||||
|
||||
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir
|
||||
RUN pip3 install -r requirements.txt --no-cache-dir
|
||||
|
||||
# Install pandoc and netcat
|
||||
# RUN python -c "import pypandoc; pypandoc.download_pandoc()"
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y pandoc netcat-openbsd \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# preload embedding model
|
||||
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
|
||||
# preload tts model
|
||||
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
|
||||
|
||||
# copy embedding weight from build
|
||||
RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
||||
COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
|
||||
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
||||
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
|
||||
|
||||
# copy built frontend files
|
||||
COPY --from=build /app/build /app/build
|
||||
|
@ -86,4 +132,6 @@ COPY --from=build /app/package.json /app/package.json
|
|||
# copy backend files
|
||||
COPY ./backend .
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD [ "bash", "start.sh"]
|
||||
|
|
22
Makefile
22
Makefile
|
@ -1,27 +1,33 @@
|
|||
|
||||
ifneq ($(shell which docker-compose 2>/dev/null),)
|
||||
DOCKER_COMPOSE := docker-compose
|
||||
else
|
||||
DOCKER_COMPOSE := docker compose
|
||||
endif
|
||||
|
||||
install:
|
||||
@docker-compose up -d
|
||||
$(DOCKER_COMPOSE) up -d
|
||||
|
||||
remove:
|
||||
@chmod +x confirm_remove.sh
|
||||
@./confirm_remove.sh
|
||||
|
||||
|
||||
start:
|
||||
@docker-compose start
|
||||
$(DOCKER_COMPOSE) start
|
||||
startAndBuild:
|
||||
docker-compose up -d --build
|
||||
$(DOCKER_COMPOSE) up -d --build
|
||||
|
||||
stop:
|
||||
@docker-compose stop
|
||||
$(DOCKER_COMPOSE) stop
|
||||
|
||||
update:
|
||||
# Calls the LLM update script
|
||||
chmod +x update_ollama_models.sh
|
||||
@./update_ollama_models.sh
|
||||
@git pull
|
||||
@docker-compose down
|
||||
$(DOCKER_COMPOSE) down
|
||||
# Make sure the ollama-webui container is stopped before rebuilding
|
||||
@docker stop open-webui || true
|
||||
@docker-compose up --build -d
|
||||
@docker-compose start
|
||||
$(DOCKER_COMPOSE) up --build -d
|
||||
$(DOCKER_COMPOSE) start
|
||||
|
||||
|
|
51
README.md
51
README.md
|
@ -25,22 +25,28 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d
|
|||
|
||||
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience.
|
||||
|
||||
- 🌈 **Theme Customization**: Choose from a variety of themes to personalize your Open WebUI experience.
|
||||
|
||||
- 💻 **Code Syntax Highlighting**: Enjoy enhanced code readability with our syntax highlighting feature.
|
||||
|
||||
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
|
||||
|
||||
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with the groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using `#` command in the prompt. In its alpha phase, occasional issues may arise as we actively refine and enhance this feature to ensure optimal performance and reliability.
|
||||
|
||||
- 🔍 **RAG Embedding Support**: Change the RAG embedding model directly in document settings, enhancing document processing. This feature supports Ollama and OpenAI models.
|
||||
|
||||
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by the URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
|
||||
|
||||
- 📜 **Prompt Preset Support**: Instantly access preset prompts using the `/` command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [Open WebUI Community](https://openwebui.com/) integration.
|
||||
|
||||
- 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data.
|
||||
- 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, followed by the option to provide textual feedback, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data.
|
||||
|
||||
- 🏷️ **Conversation Tagging**: Effortlessly categorize and locate specific chats for quick reference and streamlined data collection.
|
||||
|
||||
- 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI.
|
||||
|
||||
- 🔄 **Update All Ollama Models**: Easily update locally installed models all at once with a convenient button, streamlining model management.
|
||||
|
||||
- ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face.
|
||||
|
||||
- 🤖 **Multiple Model Support**: Seamlessly switch between different chat models for diverse interactions.
|
||||
|
@ -53,28 +59,42 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d
|
|||
|
||||
- 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.
|
||||
|
||||
- 🗨️ **Local Chat Sharing**: Generate and share chat links seamlessly between users, enhancing collaboration and communication.
|
||||
|
||||
- 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.
|
||||
|
||||
- 📜 **Chat History**: Effortlessly access and manage your conversation history.
|
||||
|
||||
- 📬 **Archive Chats**: Effortlessly store away completed conversations with LLMs for future reference, maintaining a tidy and clutter-free chat interface while allowing for easy retrieval and reference.
|
||||
|
||||
- 📤📥 **Import/Export Chat History**: Seamlessly move your chat data in and out of the platform.
|
||||
|
||||
- 🗣️ **Voice Input Support**: Engage with your model through voice interactions; enjoy the convenience of talking to your model directly. Additionally, explore the option for sending voice input automatically after 3 seconds of silence for a streamlined experience.
|
||||
|
||||
- 🔊 **Configurable Text-to-Speech Endpoint**: Customize your Text-to-Speech experience with configurable OpenAI endpoints.
|
||||
|
||||
- ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.
|
||||
|
||||
- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content.
|
||||
- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API (local), ComfyUI (local), and DALL-E, enriching your chat experience with dynamic visual content.
|
||||
|
||||
- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
|
||||
|
||||
- ✨ **Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
|
||||
|
||||
- 🔑 **API Key Generation Support**: Generate secret keys to leverage Open WebUI with OpenAI libraries, simplifying integration and development.
|
||||
|
||||
- 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.
|
||||
|
||||
- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
|
||||
|
||||
- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
|
||||
|
||||
- 🔗 **Webhook Integration**: Subscribe to new user sign-up events via webhook (compatible with Google Chat and Microsoft Teams), providing real-time notifications and automation capabilities.
|
||||
|
||||
- 🛡️ **Model Whitelisting**: Admins can whitelist models for users with the 'user' role, enhancing security and access control.
|
||||
|
||||
- 📧 **Trusted Email Authentication**: Authenticate using a trusted email header, adding an additional layer of security and authentication.
|
||||
|
||||
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
|
||||
|
||||
- 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.
|
||||
|
@ -94,24 +114,27 @@ Don't forget to explore our sibling project, [Open WebUI Community](https://open
|
|||
|
||||
### Quick Start with Docker 🐳
|
||||
|
||||
> [!IMPORTANT]
|
||||
> [!WARNING]
|
||||
> When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
|
||||
|
||||
- **If Ollama is on your computer**, use this command:
|
||||
> [!TIP]
|
||||
> If you wish to utilize Open WebUI with Ollama included or CUDA acceleration, we recommend utilizing our official images tagged with either `:cuda` or `:ollama`. To enable CUDA, you must install the [Nvidia CUDA container toolkit](https://docs.nvidia.com/dgx/nvidia-container-runtime-upgrade/) on your Linux/WSL system.
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
||||
```
|
||||
**If Ollama is on your computer**, use this command:
|
||||
|
||||
- **If Ollama is on a Different Server**, use this command:
|
||||
```bash
|
||||
docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
||||
```
|
||||
|
||||
- To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL:
|
||||
**If Ollama is on a Different Server**, use this command:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
||||
```
|
||||
To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL:
|
||||
|
||||
- After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄
|
||||
```bash
|
||||
docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
||||
```
|
||||
|
||||
After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄
|
||||
|
||||
#### Open WebUI: Server Connection Error
|
||||
|
||||
|
@ -182,4 +205,4 @@ If you have any questions, suggestions, or need assistance, please open an issue
|
|||
|
||||
---
|
||||
|
||||
Created by [Timothy J. Baek](https://github.com/tjbck) - Let's make Open Web UI even more amazing together! 💪
|
||||
Created by [Timothy J. Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪
|
||||
|
|
|
@ -10,8 +10,19 @@ from fastapi import (
|
|||
File,
|
||||
Form,
|
||||
)
|
||||
|
||||
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from faster_whisper import WhisperModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
import requests
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
from utils.utils import (
|
||||
|
@ -28,6 +39,10 @@ from config import (
|
|||
UPLOAD_DIR,
|
||||
WHISPER_MODEL,
|
||||
WHISPER_MODEL_DIR,
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
DEVICE_TYPE,
|
||||
AUDIO_OPENAI_API_BASE_URL,
|
||||
AUDIO_OPENAI_API_KEY,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -43,7 +58,103 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
|
||||
@app.post("/transcribe")
|
||||
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
|
||||
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
|
||||
|
||||
# setting device type for whisper model
|
||||
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
|
||||
log.info(f"whisper_device_type: {whisper_device_type}")
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class OpenAIConfigUpdateForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_openai_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_openai_config(
|
||||
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
if form_data.key == "":
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||
|
||||
app.state.OPENAI_API_BASE_URL = form_data.url
|
||||
app.state.OPENAI_API_KEY = form_data.key
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/speech")
|
||||
async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
name = hashlib.sha256(body).hexdigest()
|
||||
|
||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
r = None
|
||||
try:
|
||||
r = requests.post(
|
||||
url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
|
||||
data=body,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
|
||||
# Save the streaming content to a file
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(json.loads(body.decode("utf-8")), f)
|
||||
|
||||
# Return the saved file
|
||||
return FileResponse(file_path)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']['message']}"
|
||||
except:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status_code if r != None else 500,
|
||||
detail=error_detail,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/transcriptions")
|
||||
def transcribe(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_current_user),
|
||||
|
@ -64,12 +175,24 @@ def transcribe(
|
|||
f.write(contents)
|
||||
f.close()
|
||||
|
||||
model = WhisperModel(
|
||||
WHISPER_MODEL,
|
||||
device="auto",
|
||||
compute_type="int8",
|
||||
download_root=WHISPER_MODEL_DIR,
|
||||
whisper_kwargs = {
|
||||
"model_size_or_path": WHISPER_MODEL,
|
||||
"device": whisper_device_type,
|
||||
"compute_type": "int8",
|
||||
"download_root": WHISPER_MODEL_DIR,
|
||||
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
|
||||
}
|
||||
|
||||
log.debug(f"whisper_kwargs: {whisper_kwargs}")
|
||||
|
||||
try:
|
||||
model = WhisperModel(**whisper_kwargs)
|
||||
except:
|
||||
log.warning(
|
||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
||||
)
|
||||
whisper_kwargs["local_files_only"] = False
|
||||
model = WhisperModel(**whisper_kwargs)
|
||||
|
||||
segments, info = model.transcribe(file_path, beam_size=5)
|
||||
log.info(
|
||||
|
|
|
@ -24,12 +24,25 @@ from utils.misc import calculate_sha256
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
import mimetypes
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
CACHE_DIR,
|
||||
IMAGE_GENERATION_ENGINE,
|
||||
ENABLE_IMAGE_GENERATION,
|
||||
AUTOMATIC1111_BASE_URL,
|
||||
COMFYUI_BASE_URL,
|
||||
IMAGES_OPENAI_API_BASE_URL,
|
||||
IMAGES_OPENAI_API_KEY,
|
||||
IMAGE_GENERATION_MODEL,
|
||||
IMAGE_SIZE,
|
||||
IMAGE_STEPS,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -47,19 +60,21 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.state.ENGINE = ""
|
||||
app.state.ENABLED = False
|
||||
app.state.ENGINE = IMAGE_GENERATION_ENGINE
|
||||
app.state.ENABLED = ENABLE_IMAGE_GENERATION
|
||||
|
||||
app.state.OPENAI_API_KEY = ""
|
||||
app.state.MODEL = ""
|
||||
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||
|
||||
app.state.MODEL = IMAGE_GENERATION_MODEL
|
||||
|
||||
|
||||
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
|
||||
|
||||
app.state.IMAGE_SIZE = "512x512"
|
||||
app.state.IMAGE_STEPS = 50
|
||||
app.state.IMAGE_SIZE = IMAGE_SIZE
|
||||
app.state.IMAGE_STEPS = IMAGE_STEPS
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
|
@ -125,27 +140,33 @@ async def update_engine_url(
|
|||
}
|
||||
|
||||
|
||||
class OpenAIKeyUpdateForm(BaseModel):
|
||||
class OpenAIConfigUpdateForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
|
||||
@app.get("/key")
|
||||
async def get_openai_key(user=Depends(get_admin_user)):
|
||||
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
|
||||
@app.get("/openai/config")
|
||||
async def get_openai_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/key/update")
|
||||
async def update_openai_key(
|
||||
form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
|
||||
@app.post("/openai/config/update")
|
||||
async def update_openai_config(
|
||||
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
if form_data.key == "":
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||
|
||||
app.state.OPENAI_API_BASE_URL = form_data.url
|
||||
app.state.OPENAI_API_KEY = form_data.key
|
||||
|
||||
return {
|
||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
||||
"status": True,
|
||||
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
|
@ -295,35 +316,61 @@ class GenerateImageForm(BaseModel):
|
|||
|
||||
|
||||
def save_b64_image(b64_str):
|
||||
image_id = str(uuid.uuid4())
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
|
||||
|
||||
try:
|
||||
# Split the base64 string to get the actual image data
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
if "," in b64_str:
|
||||
header, encoded = b64_str.split(",", 1)
|
||||
mime_type = header.split(";")[0]
|
||||
|
||||
img_data = base64.b64decode(encoded)
|
||||
image_format = mimetypes.guess_extension(mime_type)
|
||||
|
||||
image_filename = f"{image_id}{image_format}"
|
||||
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
return image_filename
|
||||
else:
|
||||
image_filename = f"{image_id}.png"
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
|
||||
|
||||
img_data = base64.b64decode(b64_str)
|
||||
|
||||
# Write the image data to a file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
return image_filename
|
||||
|
||||
return image_id
|
||||
except Exception as e:
|
||||
log.error(f"Error saving image: {e}")
|
||||
log.exception(f"Error saving image: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_url_image(url):
|
||||
image_id = str(uuid.uuid4())
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
|
||||
|
||||
try:
|
||||
r = requests.get(url)
|
||||
r.raise_for_status()
|
||||
if r.headers["content-type"].split("/")[0] == "image":
|
||||
|
||||
mime_type = r.headers["content-type"]
|
||||
image_format = mimetypes.guess_extension(mime_type)
|
||||
|
||||
if not image_format:
|
||||
raise ValueError("Could not determine image type from MIME type")
|
||||
|
||||
image_filename = f"{image_id}{image_format}"
|
||||
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
|
||||
with open(file_path, "wb") as image_file:
|
||||
image_file.write(r.content)
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
image_file.write(chunk)
|
||||
return image_filename
|
||||
else:
|
||||
log.error(f"Url does not point to an image.")
|
||||
return None
|
||||
|
||||
return image_id
|
||||
except Exception as e:
|
||||
log.exception(f"Error saving image: {e}")
|
||||
return None
|
||||
|
@ -354,7 +401,7 @@ def generate_image(
|
|||
}
|
||||
|
||||
r = requests.post(
|
||||
url=f"https://api.openai.com/v1/images/generations",
|
||||
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
@ -365,9 +412,9 @@ def generate_image(
|
|||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
image_id = save_b64_image(image["b64_json"])
|
||||
images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
||||
image_filename = save_b64_image(image["b64_json"])
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
@ -402,9 +449,9 @@ def generate_image(
|
|||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
image_id = save_url_image(image["url"])
|
||||
images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
||||
image_filename = save_url_image(image["url"])
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(data.model_dump(exclude_none=True), f)
|
||||
|
@ -440,9 +487,9 @@ def generate_image(
|
|||
images = []
|
||||
|
||||
for image in res["images"]:
|
||||
image_id = save_b64_image(image)
|
||||
images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
||||
image_filename = save_b64_image(image)
|
||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump({**data, "info": res["info"]}, f)
|
||||
|
|
|
@ -195,7 +195,7 @@ class ImageGenerationPayload(BaseModel):
|
|||
def comfyui_generate_image(
|
||||
model: str, payload: ImageGenerationPayload, client_id, base_url
|
||||
):
|
||||
host = base_url.replace("http://", "").replace("https://", "")
|
||||
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
|
||||
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
|
||||
|
||||
|
@ -217,7 +217,7 @@ def comfyui_generate_image(
|
|||
|
||||
try:
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect(f"ws://{host}/ws?clientId={client_id}")
|
||||
ws.connect(f"{ws_url}/ws?clientId={client_id}")
|
||||
log.info("WebSocket connection established.")
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to connect to WebSocket server: {e}")
|
||||
|
|
|
@ -1,100 +1,372 @@
|
|||
import sys
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
import logging
|
||||
|
||||
from litellm.proxy.proxy_server import ProxyConfig, initialize
|
||||
from litellm.proxy.proxy_server import app
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.responses import StreamingResponse
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
|
||||
from utils.utils import get_http_authorization_cred, get_current_user
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional, List
|
||||
|
||||
from utils.utils import get_verified_user, get_current_user, get_admin_user
|
||||
from config import SRC_LOG_LEVELS, ENV
|
||||
from constants import MESSAGES
|
||||
|
||||
import os
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
|
||||
|
||||
|
||||
from config import (
|
||||
MODEL_FILTER_ENABLED,
|
||||
ENABLE_LITELLM,
|
||||
ENABLE_MODEL_FILTER,
|
||||
MODEL_FILTER_LIST,
|
||||
DATA_DIR,
|
||||
LITELLM_PROXY_PORT,
|
||||
LITELLM_PROXY_HOST,
|
||||
)
|
||||
|
||||
from litellm.utils import get_llm_provider
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import yaml
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
|
||||
|
||||
with open(LITELLM_CONFIG_DIR, "r") as file:
|
||||
litellm_config = yaml.safe_load(file)
|
||||
|
||||
|
||||
async def config():
|
||||
router, model_list, general_settings = await proxy_config.load_config(
|
||||
router=None, config_file_path="./data/litellm/config.yaml"
|
||||
app.state.ENABLE = ENABLE_LITELLM
|
||||
app.state.CONFIG = litellm_config
|
||||
|
||||
# Global variable to store the subprocess reference
|
||||
background_process = None
|
||||
|
||||
CONFLICT_ENV_VARS = [
|
||||
# Uvicorn uses PORT, so LiteLLM might use it as well
|
||||
"PORT",
|
||||
# LiteLLM uses DATABASE_URL for Prisma connections
|
||||
"DATABASE_URL",
|
||||
]
|
||||
|
||||
|
||||
async def run_background_process(command):
|
||||
global background_process
|
||||
log.info("run_background_process")
|
||||
|
||||
try:
|
||||
# Log the command to be executed
|
||||
log.info(f"Executing command: {command}")
|
||||
# Filter environment variables known to conflict with litellm
|
||||
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
|
||||
# Execute the command and create a subprocess
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
|
||||
)
|
||||
background_process = process
|
||||
log.info("Subprocess started successfully.")
|
||||
|
||||
await initialize(config="./data/litellm/config.yaml", telemetry=False)
|
||||
# Capture STDERR for debugging purposes
|
||||
stderr_output = await process.stderr.read()
|
||||
stderr_text = stderr_output.decode().strip()
|
||||
if stderr_text:
|
||||
log.info(f"Subprocess STDERR: {stderr_text}")
|
||||
|
||||
# log.info output line by line
|
||||
async for line in process.stdout:
|
||||
log.info(line.decode().strip())
|
||||
|
||||
# Wait for the process to finish
|
||||
returncode = await process.wait()
|
||||
log.info(f"Subprocess exited with return code {returncode}")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to start subprocess: {e}")
|
||||
raise # Optionally re-raise the exception if you want it to propagate
|
||||
|
||||
|
||||
async def startup():
|
||||
await config()
|
||||
async def start_litellm_background():
|
||||
log.info("start_litellm_background")
|
||||
# Command to run in the background
|
||||
command = [
|
||||
"litellm",
|
||||
"--port",
|
||||
str(LITELLM_PROXY_PORT),
|
||||
"--host",
|
||||
LITELLM_PROXY_HOST,
|
||||
"--telemetry",
|
||||
"False",
|
||||
"--config",
|
||||
LITELLM_CONFIG_DIR,
|
||||
]
|
||||
|
||||
await run_background_process(command)
|
||||
|
||||
|
||||
async def shutdown_litellm_background():
|
||||
log.info("shutdown_litellm_background")
|
||||
global background_process
|
||||
if background_process:
|
||||
background_process.terminate()
|
||||
await background_process.wait() # Ensure the process has terminated
|
||||
log.info("Subprocess terminated")
|
||||
background_process = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
await startup()
|
||||
async def startup_event():
|
||||
log.info("startup_event")
|
||||
# TODO: Check config.yaml file and create one
|
||||
asyncio.create_task(start_litellm_background())
|
||||
|
||||
|
||||
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
request.state.user = None
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {"status": True}
|
||||
|
||||
|
||||
async def restart_litellm():
|
||||
"""
|
||||
Endpoint to restart the litellm background service.
|
||||
"""
|
||||
log.info("Requested restart of litellm service.")
|
||||
try:
|
||||
user = get_current_user(get_http_authorization_cred(auth_header))
|
||||
log.debug(f"user: {user}")
|
||||
request.state.user = user
|
||||
# Shut down the existing process if it is running
|
||||
await shutdown_litellm_background()
|
||||
log.info("litellm service shutdown complete.")
|
||||
|
||||
# Restart the background service
|
||||
|
||||
asyncio.create_task(start_litellm_background())
|
||||
log.info("litellm service restart complete.")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "litellm service restarted successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"detail": str(e)})
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
log.info(f"Error restarting litellm service: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
@app.get("/restart")
|
||||
async def restart_litellm_handler(user=Depends(get_admin_user)):
|
||||
return await restart_litellm()
|
||||
|
||||
response = await call_next(request)
|
||||
user = request.state.user
|
||||
|
||||
if "/models" in request.url.path:
|
||||
if isinstance(response, StreamingResponse):
|
||||
# Read the content of the streaming response
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
@app.get("/config")
|
||||
async def get_config(user=Depends(get_admin_user)):
|
||||
return app.state.CONFIG
|
||||
|
||||
data = json.loads(body.decode("utf-8"))
|
||||
|
||||
if app.state.MODEL_FILTER_ENABLED:
|
||||
class LiteLLMConfigForm(BaseModel):
|
||||
general_settings: Optional[dict] = None
|
||||
litellm_settings: Optional[dict] = None
|
||||
model_list: Optional[List[dict]] = None
|
||||
router_settings: Optional[dict] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.CONFIG = form_data.model_dump(exclude_none=True)
|
||||
|
||||
with open(LITELLM_CONFIG_DIR, "w") as file:
|
||||
yaml.dump(app.state.CONFIG, file)
|
||||
|
||||
await restart_litellm()
|
||||
return app.state.CONFIG
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
@app.get("/v1/models")
|
||||
async def get_models(user=Depends(get_current_user)):
|
||||
|
||||
if app.state.ENABLE:
|
||||
while not background_process:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
|
||||
r = None
|
||||
try:
|
||||
r = requests.request(method="GET", url=f"{url}/models")
|
||||
r.raise_for_status()
|
||||
|
||||
data = r.json()
|
||||
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if user and user.role == "user":
|
||||
data["data"] = list(
|
||||
filter(
|
||||
lambda model: model["id"]
|
||||
in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
|
||||
data["data"],
|
||||
)
|
||||
)
|
||||
|
||||
# Modified Flag
|
||||
data["modified"] = True
|
||||
return JSONResponse(content=data)
|
||||
return data
|
||||
except Exception as e:
|
||||
|
||||
return response
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']}"
|
||||
except:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": model["model_name"],
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "openai",
|
||||
}
|
||||
for model in app.state.CONFIG["model_list"]
|
||||
],
|
||||
"object": "list",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"data": [],
|
||||
"object": "list",
|
||||
}
|
||||
|
||||
|
||||
app.add_middleware(ModifyModelsResponseMiddleware)
|
||||
@app.get("/model/info")
|
||||
async def get_model_list(user=Depends(get_admin_user)):
|
||||
return {"data": app.state.CONFIG["model_list"]}
|
||||
|
||||
|
||||
class AddLiteLLMModelForm(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@app.post("/model/new")
|
||||
async def add_model_to_config(
|
||||
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
get_llm_provider(model=form_data.model_name)
|
||||
app.state.CONFIG["model_list"].append(form_data.model_dump())
|
||||
|
||||
with open(LITELLM_CONFIG_DIR, "w") as file:
|
||||
yaml.dump(app.state.CONFIG, file)
|
||||
|
||||
await restart_litellm()
|
||||
|
||||
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
class DeleteLiteLLMModelForm(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@app.post("/model/delete")
|
||||
async def delete_model_from_config(
|
||||
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.CONFIG["model_list"] = [
|
||||
model
|
||||
for model in app.state.CONFIG["model_list"]
|
||||
if model["model_name"] != form_data.id
|
||||
]
|
||||
|
||||
with open(LITELLM_CONFIG_DIR, "w") as file:
|
||||
yaml.dump(app.state.CONFIG, file)
|
||||
|
||||
await restart_litellm()
|
||||
|
||||
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
|
||||
url = f"http://localhost:{LITELLM_PROXY_PORT}"
|
||||
|
||||
target_url = f"{url}/{path}"
|
||||
|
||||
headers = {}
|
||||
# headers["Authorization"] = f"Bearer {key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
r = None
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
data=body,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
|
||||
# Check if response is SSE
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
return StreamingResponse(
|
||||
r.iter_content(chunk_size=8192),
|
||||
status_code=r.status_code,
|
||||
headers=dict(r.headers),
|
||||
)
|
||||
else:
|
||||
response_data = r.json()
|
||||
return response_data
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status_code if r else 500, detail=error_detail
|
||||
)
|
||||
|
|
|
@ -16,6 +16,7 @@ from fastapi.concurrency import run_in_threadpool
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
import os
|
||||
import re
|
||||
import copy
|
||||
import random
|
||||
import requests
|
||||
|
@ -36,7 +37,7 @@ from utils.utils import decode_token, get_current_user, get_admin_user
|
|||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
OLLAMA_BASE_URLS,
|
||||
MODEL_FILTER_ENABLED,
|
||||
ENABLE_MODEL_FILTER,
|
||||
MODEL_FILTER_LIST,
|
||||
UPLOAD_DIR,
|
||||
)
|
||||
|
@ -55,7 +56,7 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
|
||||
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||
|
@ -168,7 +169,7 @@ async def get_ollama_tags(
|
|||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.MODEL_FILTER_ENABLED:
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
|
@ -215,7 +216,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|||
|
||||
if len(responses) > 0:
|
||||
lowest_version = min(
|
||||
responses, key=lambda x: tuple(map(int, x["version"].split(".")))
|
||||
responses,
|
||||
key=lambda x: tuple(
|
||||
map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
|
||||
),
|
||||
)
|
||||
|
||||
return {"version": lowest_version["version"]}
|
||||
|
@ -611,8 +615,13 @@ async def generate_embeddings(
|
|||
user=Depends(get_current_user),
|
||||
):
|
||||
if url_idx == None:
|
||||
if form_data.model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
@ -648,6 +657,60 @@ async def generate_embeddings(
|
|||
)
|
||||
|
||||
|
||||
def generate_ollama_embeddings(
|
||||
form_data: GenerateEmbeddingsForm,
|
||||
url_idx: Optional[int] = None,
|
||||
):
|
||||
|
||||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/embeddings",
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
data = r.json()
|
||||
|
||||
log.info(f"generate_ollama_embeddings {data}")
|
||||
|
||||
if "embedding" in data:
|
||||
return data["embedding"]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"Ollama: {res['error']}"
|
||||
except:
|
||||
error_detail = f"Ollama: {e}"
|
||||
|
||||
raise error_detail
|
||||
|
||||
|
||||
class GenerateCompletionForm(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
|
@ -671,8 +734,13 @@ async def generate_completion(
|
|||
):
|
||||
|
||||
if url_idx == None:
|
||||
if form_data.model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
@ -769,8 +837,13 @@ async def generate_chat_completion(
|
|||
):
|
||||
|
||||
if url_idx == None:
|
||||
if form_data.model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
@ -873,8 +946,13 @@ async def generate_openai_chat_completion(
|
|||
):
|
||||
|
||||
if url_idx == None:
|
||||
if form_data.model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
|
@ -24,7 +24,7 @@ from config import (
|
|||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
CACHE_DIR,
|
||||
MODEL_FILTER_ENABLED,
|
||||
ENABLE_MODEL_FILTER,
|
||||
MODEL_FILTER_LIST,
|
||||
)
|
||||
from typing import List, Optional
|
||||
|
@ -45,7 +45,7 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
|
@ -80,6 +80,7 @@ async def get_openai_urls(user=Depends(get_admin_user)):
|
|||
|
||||
@app.post("/urls/update")
|
||||
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
|
||||
await get_all_models()
|
||||
app.state.OPENAI_API_BASE_URLS = form_data.urls
|
||||
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
|
||||
|
||||
|
@ -170,6 +171,7 @@ async def fetch_url(url, key):
|
|||
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
log.info(f"merge_models_lists {model_lists}")
|
||||
merged_list = []
|
||||
|
||||
for idx, models in enumerate(model_lists):
|
||||
|
@ -198,14 +200,16 @@ async def get_all_models():
|
|||
]
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
log.info(f"get_all_models:responses() {responses}")
|
||||
|
||||
models = {
|
||||
"data": merge_models_lists(
|
||||
list(
|
||||
map(
|
||||
lambda response: (
|
||||
response["data"]
|
||||
if response and "data" in response
|
||||
else None
|
||||
if (response and "data" in response)
|
||||
else (response if isinstance(response, list) else None)
|
||||
),
|
||||
responses,
|
||||
)
|
||||
|
@ -224,7 +228,7 @@ async def get_all_models():
|
|||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
if app.state.MODEL_FILTER_ENABLED:
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["data"] = list(
|
||||
filter(
|
||||
|
@ -341,7 +345,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']}"
|
||||
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
|
|
|
@ -13,8 +13,7 @@ import os, shutil, logging, re
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
WebBaseLoader,
|
||||
|
@ -29,15 +28,22 @@ from langchain_community.document_loaders import (
|
|||
UnstructuredXMLLoader,
|
||||
UnstructuredRSTLoader,
|
||||
UnstructuredExcelLoader,
|
||||
YoutubeLoader,
|
||||
)
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
import validators
|
||||
import urllib.parse
|
||||
import socket
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import mimetypes
|
||||
import uuid
|
||||
import json
|
||||
|
||||
import sentence_transformers
|
||||
|
||||
from apps.web.models.documents import (
|
||||
Documents,
|
||||
|
@ -45,7 +51,14 @@ from apps.web.models.documents import (
|
|||
DocumentResponse,
|
||||
)
|
||||
|
||||
from apps.rag.utils import query_doc, query_collection
|
||||
from apps.rag.utils import (
|
||||
get_model_path,
|
||||
get_embedding_function,
|
||||
query_doc,
|
||||
query_doc_with_hybrid_search,
|
||||
query_collection,
|
||||
query_collection_with_hybrid_search,
|
||||
)
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
|
@ -54,16 +67,30 @@ from utils.misc import (
|
|||
extract_folders_after_data_docs,
|
||||
)
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_TOP_K,
|
||||
RAG_RELEVANCE_THRESHOLD,
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
RAG_RERANKING_MODEL,
|
||||
PDF_EXTRACT_IMAGES,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
RAG_OPENAI_API_KEY,
|
||||
DEVICE_TYPE,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
RAG_TEMPLATE,
|
||||
ENABLE_LOCAL_WEB_FETCH,
|
||||
)
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
|
@ -71,34 +98,77 @@ from constants import ERROR_MESSAGES
|
|||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
#
|
||||
# if RAG_EMBEDDING_MODEL:
|
||||
# sentence_transformer_ef = SentenceTransformer(
|
||||
# model_name_or_path=RAG_EMBEDDING_MODEL,
|
||||
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
|
||||
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
# )
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.state.PDF_EXTRACT_IMAGES = False
|
||||
app.state.TOP_K = RAG_TOP_K
|
||||
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
|
||||
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
|
||||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.TOP_K = 4
|
||||
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||
|
||||
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
||||
|
||||
|
||||
def update_embedding_model(
|
||||
embedding_model: str,
|
||||
update_model: bool = False,
|
||||
):
|
||||
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||
get_model_path(embedding_model, update_model),
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
else:
|
||||
app.state.sentence_transformer_ef = None
|
||||
|
||||
|
||||
def update_reranking_model(
|
||||
reranking_model: str,
|
||||
update_model: bool = False,
|
||||
):
|
||||
if reranking_model:
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
get_model_path(reranking_model, update_model),
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
else:
|
||||
app.state.sentence_transformer_rf = None
|
||||
|
||||
|
||||
update_embedding_model(
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
update_reranking_model(
|
||||
app.state.RAG_RERANKING_MODEL,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
|
@ -112,7 +182,7 @@ class CollectionNameForm(BaseModel):
|
|||
collection_name: Optional[str] = "test"
|
||||
|
||||
|
||||
class StoreWebForm(CollectionNameForm):
|
||||
class UrlForm(CollectionNameForm):
|
||||
url: str
|
||||
|
||||
|
||||
|
@ -123,38 +193,110 @@ async def get_status():
|
|||
"chunk_size": app.state.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"reranking_model": app.state.RAG_RERANKING_MODEL,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/embedding/model")
|
||||
async def get_embedding_model(user=Depends(get_admin_user)):
|
||||
@app.get("/embedding")
|
||||
async def get_embedding_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"openai_config": {
|
||||
"url": app.state.OPENAI_API_BASE_URL,
|
||||
"key": app.state.OPENAI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/reranking")
|
||||
async def get_reraanking_config(user=Depends(get_admin_user)):
|
||||
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
|
||||
|
||||
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
openai_config: Optional[OpenAIConfigForm] = None
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
|
||||
|
||||
@app.post("/embedding/model/update")
|
||||
async def update_embedding_model(
|
||||
@app.post("/embedding/update")
|
||||
async def update_embedding_config(
|
||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
log.info(
|
||||
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
try:
|
||||
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
if form_data.openai_config != None:
|
||||
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
||||
|
||||
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"openai_config": {
|
||||
"url": app.state.OPENAI_API_BASE_URL,
|
||||
"key": app.state.OPENAI_API_KEY,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
class RerankingModelUpdateForm(BaseModel):
|
||||
reranking_model: str
|
||||
|
||||
|
||||
@app.post("/reranking/update")
|
||||
async def update_reranking_config(
|
||||
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
log.info(
|
||||
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
||||
)
|
||||
try:
|
||||
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
||||
|
||||
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"reranking_model": app.state.RAG_RERANKING_MODEL,
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating reranking model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
|
@ -209,12 +351,16 @@ async def get_query_settings(user=Depends(get_admin_user)):
|
|||
"status": True,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"k": app.state.TOP_K,
|
||||
"r": app.state.RELEVANCE_THRESHOLD,
|
||||
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
|
||||
}
|
||||
|
||||
|
||||
class QuerySettingsForm(BaseModel):
|
||||
k: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
template: Optional[str] = None
|
||||
hybrid: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/query/settings/update")
|
||||
|
@ -223,13 +369,23 @@ async def update_query_settings(
|
|||
):
|
||||
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
||||
app.state.TOP_K = form_data.k if form_data.k else 4
|
||||
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
||||
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
||||
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
|
||||
return {
|
||||
"status": True,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"k": app.state.TOP_K,
|
||||
"r": app.state.RELEVANCE_THRESHOLD,
|
||||
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
|
||||
}
|
||||
|
||||
|
||||
class QueryDocForm(BaseModel):
|
||||
collection_name: str
|
||||
query: str
|
||||
k: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
hybrid: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/query/doc")
|
||||
|
@ -237,13 +393,22 @@ def query_doc_handler(
|
|||
form_data: QueryDocForm,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
||||
return query_doc_with_hybrid_search(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
)
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
@ -257,6 +422,8 @@ class QueryCollectionsForm(BaseModel):
|
|||
collection_names: List[str]
|
||||
query: str
|
||||
k: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
hybrid: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/query/collection")
|
||||
|
@ -264,19 +431,36 @@ def query_collection_handler(
|
|||
form_data: QueryCollectionsForm,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
||||
return query_collection_with_hybrid_search(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
)
|
||||
else:
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/web")
|
||||
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||
@app.post("/youtube")
|
||||
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
try:
|
||||
loader = WebBaseLoader(form_data.url)
|
||||
loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False)
|
||||
data = loader.load()
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
|
@ -297,6 +481,62 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|||
)
|
||||
|
||||
|
||||
@app.post("/web")
|
||||
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||
try:
|
||||
loader = get_web_loader(form_data.url)
|
||||
data = loader.load()
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == "":
|
||||
collection_name = calculate_sha256_string(form_data.url)[:63]
|
||||
|
||||
store_data_in_vector_db(data, collection_name, overwrite=True)
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"filename": form_data.url,
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
def get_web_loader(url: str):
|
||||
# Check if the URL is valid
|
||||
if isinstance(validators.url(url), validators.ValidationError):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
if not ENABLE_LOCAL_WEB_FETCH:
|
||||
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
# Get IPv4 and IPv6 addresses
|
||||
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
|
||||
# Check if any of the resolved addresses are private
|
||||
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
|
||||
for ip in ipv4_addresses:
|
||||
if validators.ipv4(ip, private=True):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
for ip in ipv6_addresses:
|
||||
if validators.ipv6(ip, private=True):
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
return WebBaseLoader(url)
|
||||
|
||||
|
||||
def resolve_hostname(hostname):
|
||||
# Get address information
|
||||
addr_info = socket.getaddrinfo(hostname, None)
|
||||
|
||||
# Extract IP addresses from address information
|
||||
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
||||
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
||||
|
||||
return ipv4_addresses, ipv6_addresses
|
||||
|
||||
|
||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
|
@ -304,9 +544,11 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
|
|||
chunk_overlap=app.state.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
|
||||
docs = text_splitter.split_documents(data)
|
||||
|
||||
if len(docs) > 0:
|
||||
log.info(f"store_data_in_vector_db {docs}")
|
||||
return store_docs_in_vector_db(docs, collection_name, overwrite), None
|
||||
else:
|
||||
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
||||
|
@ -325,6 +567,7 @@ def store_text_in_vector_db(
|
|||
|
||||
|
||||
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
||||
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
@ -336,14 +579,28 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
log.info(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
embedding_func = get_embedding_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
collection.add(
|
||||
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
||||
)
|
||||
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
embeddings = embedding_func(embedding_texts)
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings,
|
||||
documents=texts,
|
||||
):
|
||||
collection.add(*batch)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
|
@ -1,97 +1,189 @@
|
|||
import re
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from typing import List
|
||||
|
||||
from apps.ollama.main import (
|
||||
generate_ollama_embeddings,
|
||||
GenerateEmbeddingsForm,
|
||||
)
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain.retrievers import (
|
||||
ContextualCompressionRetriever,
|
||||
EnsembleRetriever,
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def query_doc(collection_name: str, query: str, k: int, embedding_function):
|
||||
def query_doc(
|
||||
collection_name: str,
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
):
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
query_embeddings = embedding_function(query)
|
||||
|
||||
result = collection.query(
|
||||
query_texts=[query],
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=k,
|
||||
)
|
||||
|
||||
log.info(f"query_doc:result {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(query_results, k):
|
||||
# Initialize lists to store combined data
|
||||
combined_ids = []
|
||||
combined_distances = []
|
||||
combined_metadatas = []
|
||||
combined_documents = []
|
||||
def query_doc_with_hybrid_search(
|
||||
collection_name: str,
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
):
|
||||
try:
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
documents = collection.get() # get all documents
|
||||
|
||||
# Combine data from each dictionary
|
||||
for data in query_results:
|
||||
combined_ids.extend(data["ids"][0])
|
||||
combined_distances.extend(data["distances"][0])
|
||||
combined_metadatas.extend(data["metadatas"][0])
|
||||
combined_documents.extend(data["documents"][0])
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
texts=documents.get("documents"),
|
||||
metadatas=documents.get("metadatas"),
|
||||
)
|
||||
bm25_retriever.k = k
|
||||
|
||||
# Create a list of tuples (distance, id, metadata, document)
|
||||
combined = list(
|
||||
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
|
||||
chroma_retriever = ChromaRetriever(
|
||||
collection=collection,
|
||||
embedding_function=embedding_function,
|
||||
top_n=k,
|
||||
)
|
||||
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0])
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
|
||||
)
|
||||
|
||||
compressor = RerankCompressor(
|
||||
embedding_function=embedding_function,
|
||||
top_n=k,
|
||||
reranking_function=reranking_function,
|
||||
r_score=r,
|
||||
)
|
||||
|
||||
compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=compressor, base_retriever=ensemble_retriever
|
||||
)
|
||||
|
||||
result = compression_retriever.invoke(query)
|
||||
result = {
|
||||
"distances": [[d.metadata.get("score") for d in result]],
|
||||
"documents": [[d.page_content for d in result]],
|
||||
"metadatas": [[d.metadata for d in result]],
|
||||
}
|
||||
|
||||
log.info(f"query_doc_with_hybrid_search:result {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(query_results, k, reverse=False):
|
||||
# Initialize lists to store combined data
|
||||
combined_distances = []
|
||||
combined_documents = []
|
||||
combined_metadatas = []
|
||||
|
||||
for data in query_results:
|
||||
combined_distances.extend(data["distances"][0])
|
||||
combined_documents.extend(data["documents"][0])
|
||||
combined_metadatas.extend(data["metadatas"][0])
|
||||
|
||||
# Create a list of tuples (distance, document, metadata)
|
||||
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
||||
|
||||
# Sort the list based on distances
|
||||
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||
|
||||
# We don't have anything :-(
|
||||
if not combined:
|
||||
sorted_distances = []
|
||||
sorted_documents = []
|
||||
sorted_metadatas = []
|
||||
else:
|
||||
# Unzip the sorted list
|
||||
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
|
||||
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
|
||||
|
||||
# Slicing the lists to include only k elements
|
||||
sorted_distances = list(sorted_distances)[:k]
|
||||
sorted_ids = list(sorted_ids)[:k]
|
||||
sorted_metadatas = list(sorted_metadatas)[:k]
|
||||
sorted_documents = list(sorted_documents)[:k]
|
||||
sorted_metadatas = list(sorted_metadatas)[:k]
|
||||
|
||||
# Create the output dictionary
|
||||
merged_query_results = {
|
||||
"ids": [sorted_ids],
|
||||
result = {
|
||||
"distances": [sorted_distances],
|
||||
"metadatas": [sorted_metadatas],
|
||||
"documents": [sorted_documents],
|
||||
"embeddings": None,
|
||||
"uris": None,
|
||||
"data": None,
|
||||
"metadatas": [sorted_metadatas],
|
||||
}
|
||||
|
||||
return merged_query_results
|
||||
return result
|
||||
|
||||
|
||||
def query_collection(
|
||||
collection_names: List[str], query: str, k: int, embedding_function
|
||||
collection_names: List[str],
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
):
|
||||
|
||||
results = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
result = query_doc(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
|
||||
result = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=k,
|
||||
)
|
||||
results.append(result)
|
||||
except:
|
||||
pass
|
||||
return merge_and_sort_query_results(results, k=k)
|
||||
|
||||
return merge_and_sort_query_results(results, k)
|
||||
|
||||
def query_collection_with_hybrid_search(
|
||||
collection_names: List[str],
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
):
|
||||
results = []
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
result = query_doc_with_hybrid_search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
results.append(result)
|
||||
except:
|
||||
pass
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
||||
|
||||
def rag_template(template: str, context: str, query: str):
|
||||
|
@ -100,8 +192,53 @@ def rag_template(template: str, context: str, query: str):
|
|||
return template
|
||||
|
||||
|
||||
def rag_messages(docs, messages, template, k, embedding_function):
|
||||
log.debug(f"docs: {docs}")
|
||||
def get_embedding_function(
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
if embedding_engine == "ollama":
|
||||
func = lambda query: generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
func = lambda query: generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
|
||||
def generate_multiple(query, f):
|
||||
if isinstance(query, list):
|
||||
return [f(q) for q in query]
|
||||
else:
|
||||
return f(query)
|
||||
|
||||
return lambda query: generate_multiple(query, func)
|
||||
|
||||
|
||||
def rag_messages(
|
||||
docs,
|
||||
messages,
|
||||
template,
|
||||
embedding_function,
|
||||
k,
|
||||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
):
|
||||
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
|
||||
|
||||
last_user_message_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
|
@ -128,40 +265,69 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
content_type = None
|
||||
query = ""
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
|
||||
for doc in docs:
|
||||
context = None
|
||||
|
||||
collection = doc.get("collection_name")
|
||||
if collection:
|
||||
collection = [collection]
|
||||
else:
|
||||
collection = doc.get("collection_names", [])
|
||||
|
||||
collection = set(collection).difference(extracted_collections)
|
||||
if not collection:
|
||||
log.debug(f"skipping {doc} as it has already been extracted")
|
||||
continue
|
||||
|
||||
try:
|
||||
if doc["type"] == "collection":
|
||||
context = query_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
elif doc["type"] == "text":
|
||||
if doc["type"] == "text":
|
||||
context = doc["content"]
|
||||
else:
|
||||
context = query_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
if hybrid_search:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=(
|
||||
doc["collection_names"]
|
||||
if doc["type"] == "collection"
|
||||
else [doc["collection_name"]]
|
||||
),
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
else:
|
||||
context = query_collection(
|
||||
collection_names=(
|
||||
doc["collection_names"]
|
||||
if doc["type"] == "collection"
|
||||
else [doc["collection_name"]]
|
||||
),
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
context = None
|
||||
|
||||
if context:
|
||||
relevant_contexts.append(context)
|
||||
|
||||
log.debug(f"relevant_contexts: {relevant_contexts}")
|
||||
extracted_collections.extend(collection)
|
||||
|
||||
context_string = ""
|
||||
for context in relevant_contexts:
|
||||
if context:
|
||||
context_string += " ".join(context["documents"][0]) + "\n"
|
||||
try:
|
||||
if "documents" in context:
|
||||
items = [item for item in context["documents"][0] if item is not None]
|
||||
context_string += "\n\n".join(items)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
context_string = context_string.strip()
|
||||
|
||||
ra_content = rag_template(
|
||||
template=template,
|
||||
|
@ -169,6 +335,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
query=query,
|
||||
)
|
||||
|
||||
log.debug(f"ra_content: {ra_content}")
|
||||
|
||||
if content_type == "list":
|
||||
new_content = []
|
||||
for content_item in user_message["content"]:
|
||||
|
@ -188,3 +356,162 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
messages[last_user_message_idx] = new_user_message
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
|
||||
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
|
||||
|
||||
local_files_only = not update_model
|
||||
|
||||
snapshot_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
log.debug(f"model: {model}")
|
||||
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
|
||||
|
||||
# Inspiration from upstream sentence_transformers
|
||||
if (
|
||||
os.path.exists(model)
|
||||
or ("\\" in model or model.count("/") > 1)
|
||||
and local_files_only
|
||||
):
|
||||
# If fully qualified path exists, return input, else set repo_id
|
||||
return model
|
||||
elif "/" not in model:
|
||||
# Set valid repo_id for model short-name
|
||||
model = "sentence-transformers" + "/" + model
|
||||
|
||||
snapshot_kwargs["repo_id"] = model
|
||||
|
||||
# Attempt to query the huggingface_hub library to determine the local path and/or to update
|
||||
try:
|
||||
model_repo_path = snapshot_download(**snapshot_kwargs)
|
||||
log.debug(f"model_repo_path: {model_repo_path}")
|
||||
return model_repo_path
|
||||
except Exception as e:
|
||||
log.exception(f"Cannot determine model snapshot path: {e}")
|
||||
return model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
|
||||
):
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/embeddings",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
},
|
||||
json={"input": text, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return data["data"][0]["embedding"]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
|
||||
|
||||
class ChromaRetriever(BaseRetriever):
|
||||
collection: Any
|
||||
embedding_function: Any
|
||||
top_n: int
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
) -> List[Document]:
|
||||
query_embeddings = self.embedding_function(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=self.top_n,
|
||||
)
|
||||
|
||||
ids = results["ids"][0]
|
||||
metadatas = results["metadatas"][0]
|
||||
documents = results["documents"][0]
|
||||
|
||||
results = []
|
||||
for idx in range(len(ids)):
|
||||
results.append(
|
||||
Document(
|
||||
metadata=metadatas[idx],
|
||||
page_content=documents[idx],
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
import operator
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.pydantic_v1 import Extra
|
||||
|
||||
from sentence_transformers import util
|
||||
|
||||
|
||||
class RerankCompressor(BaseDocumentCompressor):
|
||||
embedding_function: Any
|
||||
top_n: int
|
||||
reranking_function: Any
|
||||
r_score: float
|
||||
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
reranking = self.reranking_function is not None
|
||||
|
||||
if reranking:
|
||||
scores = self.reranking_function.predict(
|
||||
[(query, doc.page_content) for doc in documents]
|
||||
)
|
||||
else:
|
||||
query_embedding = self.embedding_function(query)
|
||||
document_embedding = self.embedding_function(
|
||||
[doc.page_content for doc in documents]
|
||||
)
|
||||
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
||||
|
||||
docs_with_scores = list(zip(documents, scores.tolist()))
|
||||
if self.r_score:
|
||||
docs_with_scores = [
|
||||
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
||||
]
|
||||
|
||||
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
||||
final_results = []
|
||||
for doc, doc_score in result[: self.top_n]:
|
||||
metadata = doc.metadata
|
||||
metadata["score"] = doc_score
|
||||
doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata=metadata,
|
||||
)
|
||||
final_results.append(doc)
|
||||
return final_results
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from peewee import *
|
||||
from peewee_migrate import Router
|
||||
from config import SRC_LOG_LEVELS, DATA_DIR
|
||||
from playhouse.db_url import connect
|
||||
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
|
||||
import os
|
||||
import logging
|
||||
|
||||
|
@ -11,12 +12,12 @@ log.setLevel(SRC_LOG_LEVELS["DB"])
|
|||
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
# Rename the file
|
||||
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
|
||||
log.info("File renamed successfully.")
|
||||
log.info("Database migrated from Ollama-WebUI successfully.")
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
|
||||
DB = connect(DATABASE_URL)
|
||||
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
||||
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
|
||||
router.run()
|
||||
DB.connect(reuse_if_open=True)
|
||||
|
|
|
@ -37,6 +37,18 @@ with suppress(ImportError):
|
|||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# We perform different migrations for SQLite and other databases
|
||||
# This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite
|
||||
# will require per-database SQL queries.
|
||||
# Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base
|
||||
# schema instead of trying to migrate from an older schema.
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
migrate_sqlite(migrator, database, fake=fake)
|
||||
else:
|
||||
migrate_external(migrator, database, fake=fake)
|
||||
|
||||
|
||||
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
@migrator.create_model
|
||||
class Auth(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
|
@ -53,7 +65,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.CharField()
|
||||
chat = pw.TextField()
|
||||
timestamp = pw.DateField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chat"
|
||||
|
@ -64,7 +76,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
tag_name = pw.CharField(max_length=255)
|
||||
chat_id = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.DateField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chatidtag"
|
||||
|
@ -78,7 +90,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
filename = pw.CharField()
|
||||
content = pw.TextField(null=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.DateField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "document"
|
||||
|
@ -89,7 +101,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
tag_name = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
modelfile = pw.TextField()
|
||||
timestamp = pw.DateField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "modelfile"
|
||||
|
@ -101,7 +113,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.CharField()
|
||||
content = pw.TextField()
|
||||
timestamp = pw.DateField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "prompt"
|
||||
|
@ -123,7 +135,100 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
email = pw.CharField(max_length=255)
|
||||
role = pw.CharField(max_length=255)
|
||||
profile_image_url = pw.CharField(max_length=255)
|
||||
timestamp = pw.DateField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "user"
|
||||
|
||||
|
||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
@migrator.create_model
|
||||
class Auth(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
email = pw.CharField(max_length=255)
|
||||
password = pw.TextField()
|
||||
active = pw.BooleanField()
|
||||
|
||||
class Meta:
|
||||
table_name = "auth"
|
||||
|
||||
@migrator.create_model
|
||||
class Chat(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.TextField()
|
||||
chat = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chat"
|
||||
|
||||
@migrator.create_model
|
||||
class ChatIdTag(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
tag_name = pw.CharField(max_length=255)
|
||||
chat_id = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chatidtag"
|
||||
|
||||
@migrator.create_model
|
||||
class Document(pw.Model):
|
||||
id = pw.AutoField()
|
||||
collection_name = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255, unique=True)
|
||||
title = pw.TextField()
|
||||
filename = pw.TextField()
|
||||
content = pw.TextField(null=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "document"
|
||||
|
||||
@migrator.create_model
|
||||
class Modelfile(pw.Model):
|
||||
id = pw.AutoField()
|
||||
tag_name = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
modelfile = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "modelfile"
|
||||
|
||||
@migrator.create_model
|
||||
class Prompt(pw.Model):
|
||||
id = pw.AutoField()
|
||||
command = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.TextField()
|
||||
content = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "prompt"
|
||||
|
||||
@migrator.create_model
|
||||
class Tag(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
data = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "tag"
|
||||
|
||||
@migrator.create_model
|
||||
class User(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255)
|
||||
email = pw.CharField(max_length=255)
|
||||
role = pw.CharField(max_length=255)
|
||||
profile_image_url = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "user"
|
||||
|
|
46
backend/apps/web/internal/migrations/004_add_archived.py
Normal file
46
backend/apps/web/internal/migrations/004_add_archived.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("chat", "archived")
|
130
backend/apps/web/internal/migrations/005_add_updated_at.py
Normal file
130
backend/apps/web/internal/migrations/005_add_updated_at.py
Normal file
|
@ -0,0 +1,130 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
migrate_sqlite(migrator, database, fake=fake)
|
||||
else:
|
||||
migrate_external(migrator, database, fake=fake)
|
||||
|
||||
|
||||
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Adding fields created_at and updated_at to the 'chat' table
|
||||
migrator.add_fields(
|
||||
"chat",
|
||||
created_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||
updated_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
||||
)
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("chat", "timestamp")
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
created_at=pw.DateTimeField(null=False),
|
||||
updated_at=pw.DateTimeField(null=False),
|
||||
)
|
||||
|
||||
|
||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Adding fields created_at and updated_at to the 'chat' table
|
||||
migrator.add_fields(
|
||||
"chat",
|
||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
||||
)
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("chat", "timestamp")
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
created_at=pw.BigIntegerField(null=False),
|
||||
updated_at=pw.BigIntegerField(null=False),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
rollback_sqlite(migrator, database, fake=fake)
|
||||
else:
|
||||
rollback_external(migrator, database, fake=fake)
|
||||
|
||||
|
||||
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Recreate the timestamp field initially allowing null values for safe transition
|
||||
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
|
||||
|
||||
# Copy the earliest created_at date back into the new timestamp field
|
||||
# This assumes created_at was originally a copy of timestamp
|
||||
migrator.sql("UPDATE chat SET timestamp = created_at")
|
||||
|
||||
# Remove the created_at and updated_at fields
|
||||
migrator.remove_fields("chat", "created_at", "updated_at")
|
||||
|
||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
|
||||
|
||||
|
||||
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Recreate the timestamp field initially allowing null values for safe transition
|
||||
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
|
||||
|
||||
# Copy the earliest created_at date back into the new timestamp field
|
||||
# This assumes created_at was originally a copy of timestamp
|
||||
migrator.sql("UPDATE chat SET timestamp = created_at")
|
||||
|
||||
# Remove the created_at and updated_at fields
|
||||
migrator.remove_fields("chat", "created_at", "updated_at")
|
||||
|
||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
|
|
@ -0,0 +1,130 @@
|
|||
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Alter the tables with timestamps
|
||||
migrator.change_fields(
|
||||
"chatidtag",
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"modelfile",
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
# Alter the tables with varchar to text where necessary
|
||||
migrator.change_fields(
|
||||
"auth",
|
||||
password=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
title=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
title=pw.TextField(),
|
||||
filename=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
title=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
profile_image_url=pw.TextField(),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
# Alter the tables with timestamps
|
||||
migrator.change_fields(
|
||||
"chatidtag",
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"modelfile",
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"auth",
|
||||
password=pw.CharField(max_length=255),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
title=pw.CharField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
title=pw.CharField(),
|
||||
filename=pw.CharField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
title=pw.CharField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
profile_image_url=pw.CharField(),
|
||||
)
|
|
@ -0,0 +1,79 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields created_at and updated_at to the 'user' table
|
||||
migrator.add_fields(
|
||||
"user",
|
||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL'
|
||||
)
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("user", "timestamp")
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
created_at=pw.BigIntegerField(null=False),
|
||||
updated_at=pw.BigIntegerField(null=False),
|
||||
last_active_at=pw.BigIntegerField(null=False),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Recreate the timestamp field initially allowing null values for safe transition
|
||||
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
|
||||
|
||||
# Copy the earliest created_at date back into the new timestamp field
|
||||
# This assumes created_at was originally a copy of timestamp
|
||||
migrator.sql('UPDATE "user" SET timestamp = created_at')
|
||||
|
||||
# Remove the created_at and updated_at fields
|
||||
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
|
||||
|
||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
|
|
@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|||
class Auth(Model):
|
||||
id = CharField(unique=True)
|
||||
email = CharField()
|
||||
password = CharField()
|
||||
password = TextField()
|
||||
active = BooleanField()
|
||||
|
||||
class Meta:
|
||||
|
@ -86,6 +86,11 @@ class SignupForm(BaseModel):
|
|||
name: str
|
||||
email: str
|
||||
password: str
|
||||
profile_image_url: Optional[str] = "/user.png"
|
||||
|
||||
|
||||
class AddUserForm(SignupForm):
|
||||
role: Optional[str] = "pending"
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
|
@ -94,7 +99,12 @@ class AuthsTable:
|
|||
self.db.create_tables([Auth])
|
||||
|
||||
def insert_new_auth(
|
||||
self, email: str, password: str, name: str, role: str = "pending"
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
name: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
) -> Optional[UserModel]:
|
||||
log.info("insert_new_auth")
|
||||
|
||||
|
@ -105,7 +115,7 @@ class AuthsTable:
|
|||
)
|
||||
result = Auth.create(**auth.model_dump())
|
||||
|
||||
user = Users.insert_new_user(id, name, email, role)
|
||||
user = Users.insert_new_user(id, name, email, profile_image_url, role)
|
||||
|
||||
if result and user:
|
||||
return user
|
||||
|
|
|
@ -17,10 +17,14 @@ from apps.web.internal.db import DB
|
|||
class Chat(Model):
|
||||
id = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
title = CharField()
|
||||
title = TextField()
|
||||
chat = TextField() # Save Chat JSON as Text
|
||||
timestamp = DateField()
|
||||
|
||||
created_at = BigIntegerField()
|
||||
updated_at = BigIntegerField()
|
||||
|
||||
share_id = CharField(null=True, unique=True)
|
||||
archived = BooleanField(default=False)
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
@ -31,8 +35,12 @@ class ChatModel(BaseModel):
|
|||
user_id: str
|
||||
title: str
|
||||
chat: str
|
||||
timestamp: int # timestamp in epoch
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
share_id: Optional[str] = None
|
||||
archived: bool = False
|
||||
|
||||
|
||||
####################
|
||||
|
@ -53,13 +61,17 @@ class ChatResponse(BaseModel):
|
|||
user_id: str
|
||||
title: str
|
||||
chat: dict
|
||||
timestamp: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
share_id: Optional[str] = None # id of the chat to be shared
|
||||
archived: bool
|
||||
|
||||
|
||||
class ChatTitleIdResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
updated_at: int
|
||||
created_at: int
|
||||
|
||||
|
||||
class ChatTable:
|
||||
|
@ -77,7 +89,8 @@ class ChatTable:
|
|||
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
|
||||
),
|
||||
"chat": json.dumps(form_data.chat),
|
||||
"timestamp": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -89,7 +102,7 @@ class ChatTable:
|
|||
query = Chat.update(
|
||||
chat=json.dumps(chat),
|
||||
title=chat["title"] if "title" in chat else "New Chat",
|
||||
timestamp=int(time.time()),
|
||||
updated_at=int(time.time()),
|
||||
).where(Chat.id == id)
|
||||
query.execute()
|
||||
|
||||
|
@ -111,7 +124,8 @@ class ChatTable:
|
|||
"user_id": f"shared-{chat_id}",
|
||||
"title": chat.title,
|
||||
"chat": chat.chat,
|
||||
"timestamp": int(time.time()),
|
||||
"created_at": chat.created_at,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
shared_result = Chat.create(**shared_chat.model_dump())
|
||||
|
@ -163,40 +177,55 @@ class ChatTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def get_chat_lists_by_user_id(
|
||||
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
chat = self.get_chat_by_id(id)
|
||||
query = Chat.update(
|
||||
archived=(not chat.archived),
|
||||
).where(Chat.id == id)
|
||||
|
||||
query.execute()
|
||||
|
||||
chat = Chat.get(Chat.id == id)
|
||||
return ChatModel(**model_to_dict(chat))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_archived_chat_list_by_user_id(
|
||||
self, user_id: str, skip: int = 0, limit: int = 50
|
||||
) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select()
|
||||
.where(Chat.archived == True)
|
||||
.where(Chat.user_id == user_id)
|
||||
.order_by(Chat.timestamp.desc())
|
||||
.order_by(Chat.updated_at.desc())
|
||||
# .limit(limit)
|
||||
# .offset(skip)
|
||||
]
|
||||
|
||||
def get_chat_lists_by_chat_ids(
|
||||
def get_chat_list_by_user_id(
|
||||
self, user_id: str, skip: int = 0, limit: int = 50
|
||||
) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select()
|
||||
.where(Chat.archived == False)
|
||||
.where(Chat.user_id == user_id)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
# .limit(limit)
|
||||
# .offset(skip)
|
||||
]
|
||||
|
||||
def get_chat_list_by_chat_ids(
|
||||
self, chat_ids: List[str], skip: int = 0, limit: int = 50
|
||||
) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select()
|
||||
.where(Chat.archived == False)
|
||||
.where(Chat.id.in_(chat_ids))
|
||||
.order_by(Chat.timestamp.desc())
|
||||
]
|
||||
|
||||
def get_all_chats(self) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select().order_by(Chat.timestamp.desc())
|
||||
]
|
||||
|
||||
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select()
|
||||
.where(Chat.user_id == user_id)
|
||||
.order_by(Chat.timestamp.desc())
|
||||
.order_by(Chat.updated_at.desc())
|
||||
]
|
||||
|
||||
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
|
@ -206,6 +235,18 @@ class ChatTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
chat = Chat.get(Chat.share_id == id)
|
||||
|
||||
if chat:
|
||||
chat = Chat.get(Chat.id == id)
|
||||
return ChatModel(**model_to_dict(chat))
|
||||
else:
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
|
||||
|
@ -216,9 +257,28 @@ class ChatTable:
|
|||
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select().limit(limit).offset(skip)
|
||||
for chat in Chat.select().order_by(Chat.updated_at.desc())
|
||||
# .limit(limit).offset(skip)
|
||||
]
|
||||
|
||||
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||
return [
|
||||
ChatModel(**model_to_dict(chat))
|
||||
for chat in Chat.select()
|
||||
.where(Chat.user_id == user_id)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
# .limit(limit).offset(skip)
|
||||
]
|
||||
|
||||
def delete_chat_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
query = Chat.delete().where((Chat.id == id))
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True and self.delete_shared_chat_by_chat_id(id)
|
||||
except:
|
||||
return False
|
||||
|
||||
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
|
||||
|
|
|
@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|||
class Document(Model):
|
||||
collection_name = CharField(unique=True)
|
||||
name = CharField(unique=True)
|
||||
title = CharField()
|
||||
filename = CharField()
|
||||
title = TextField()
|
||||
filename = TextField()
|
||||
content = TextField(null=True)
|
||||
user_id = CharField()
|
||||
timestamp = DateField()
|
||||
timestamp = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
|
|
@ -20,7 +20,7 @@ class Modelfile(Model):
|
|||
tag_name = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
modelfile = TextField()
|
||||
timestamp = DateField()
|
||||
timestamp = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
|
|
@ -19,9 +19,9 @@ import json
|
|||
class Prompt(Model):
|
||||
command = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
title = CharField()
|
||||
title = TextField()
|
||||
content = TextField()
|
||||
timestamp = DateField()
|
||||
timestamp = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
|
|
@ -35,7 +35,7 @@ class ChatIdTag(Model):
|
|||
tag_name = CharField()
|
||||
chat_id = CharField()
|
||||
user_id = CharField()
|
||||
timestamp = DateField()
|
||||
timestamp = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
@ -136,7 +136,9 @@ class TagTable:
|
|||
|
||||
return [
|
||||
TagModel(**model_to_dict(tag))
|
||||
for tag in Tag.select().where(Tag.name.in_(tag_names))
|
||||
for tag in Tag.select()
|
||||
.where(Tag.user_id == user_id)
|
||||
.where(Tag.name.in_(tag_names))
|
||||
]
|
||||
|
||||
def get_tags_by_chat_id_and_user_id(
|
||||
|
@ -151,7 +153,9 @@ class TagTable:
|
|||
|
||||
return [
|
||||
TagModel(**model_to_dict(tag))
|
||||
for tag in Tag.select().where(Tag.name.in_(tag_names))
|
||||
for tag in Tag.select()
|
||||
.where(Tag.user_id == user_id)
|
||||
.where(Tag.name.in_(tag_names))
|
||||
]
|
||||
|
||||
def get_chat_ids_by_tag_name_and_user_id(
|
||||
|
|
|
@ -18,8 +18,12 @@ class User(Model):
|
|||
name = CharField()
|
||||
email = CharField()
|
||||
role = CharField()
|
||||
profile_image_url = CharField()
|
||||
timestamp = DateField()
|
||||
profile_image_url = TextField()
|
||||
|
||||
last_active_at = BigIntegerField()
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
api_key = CharField(null=True, unique=True)
|
||||
|
||||
class Meta:
|
||||
|
@ -31,8 +35,12 @@ class UserModel(BaseModel):
|
|||
name: str
|
||||
email: str
|
||||
role: str = "pending"
|
||||
profile_image_url: str = "/user.png"
|
||||
timestamp: int # timestamp in epoch
|
||||
profile_image_url: str
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -59,7 +67,12 @@ class UsersTable:
|
|||
self.db.create_tables([User])
|
||||
|
||||
def insert_new_user(
|
||||
self, id: str, name: str, email: str, role: str = "pending"
|
||||
self,
|
||||
id: str,
|
||||
name: str,
|
||||
email: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
) -> Optional[UserModel]:
|
||||
user = UserModel(
|
||||
**{
|
||||
|
@ -67,8 +80,10 @@ class UsersTable:
|
|||
"name": name,
|
||||
"email": email,
|
||||
"role": role,
|
||||
"profile_image_url": "/user.png",
|
||||
"timestamp": int(time.time()),
|
||||
"profile_image_url": profile_image_url,
|
||||
"last_active_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
result = User.create(**user.model_dump())
|
||||
|
@ -108,6 +123,13 @@ class UsersTable:
|
|||
def get_num_users(self) -> Optional[int]:
|
||||
return User.select().count()
|
||||
|
||||
def get_first_user(self) -> UserModel:
|
||||
try:
|
||||
user = User.select().order_by(User.created_at).first()
|
||||
return UserModel(**model_to_dict(user))
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
||||
try:
|
||||
query = User.update(role=role).where(User.id == id)
|
||||
|
@ -132,6 +154,16 @@ class UsersTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
try:
|
||||
query = User.update(last_active_at=int(time.time())).where(User.id == id)
|
||||
query.execute()
|
||||
|
||||
user = User.get(User.id == id)
|
||||
return UserModel(**model_to_dict(user))
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
try:
|
||||
query = User.update(**updated).where(User.id == id)
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
from fastapi import Request
|
||||
import logging
|
||||
|
||||
from fastapi import Request, UploadFile, File
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import re
|
||||
import uuid
|
||||
import csv
|
||||
|
||||
|
||||
from apps.web.models.auths import (
|
||||
SigninForm,
|
||||
SignupForm,
|
||||
AddUserForm,
|
||||
UpdateProfileForm,
|
||||
UpdatePasswordForm,
|
||||
UserResponse,
|
||||
|
@ -163,7 +168,11 @@ async def signup(request: Request, form_data: SignupForm):
|
|||
)
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(), hashed, form_data.name, role
|
||||
form_data.email.lower(),
|
||||
hashed,
|
||||
form_data.name,
|
||||
form_data.profile_image_url,
|
||||
role,
|
||||
)
|
||||
|
||||
if user:
|
||||
|
@ -199,6 +208,51 @@ async def signup(request: Request, form_data: SignupForm):
|
|||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
|
||||
|
||||
############################
|
||||
# AddUser
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/add", response_model=SigninResponse)
|
||||
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
)
|
||||
|
||||
if Users.get_user_by_email(form_data.email.lower()):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
|
||||
print(form_data)
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(),
|
||||
hashed,
|
||||
form_data.name,
|
||||
form_data.profile_image_url,
|
||||
form_data.role,
|
||||
)
|
||||
|
||||
if user:
|
||||
token = create_token(data={"id": user.id})
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
|
||||
|
||||
############################
|
||||
# ToggleSignUp
|
||||
############################
|
||||
|
|
|
@ -28,7 +28,7 @@ from apps.web.models.tags import (
|
|||
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
@ -36,27 +36,73 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetChats
|
||||
# GetChatList
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[ChatTitleIdResponse])
|
||||
async def get_user_chats(
|
||||
@router.get("/list", response_model=List[ChatTitleIdResponse])
|
||||
async def get_session_user_chat_list(
|
||||
user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
|
||||
return Chats.get_chat_list_by_user_id(user.id, skip, limit)
|
||||
|
||||
|
||||
############################
|
||||
# GetAllChats
|
||||
# DeleteAllChats
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/", response_model=bool)
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
|
||||
|
||||
if (
|
||||
user.role == "user"
|
||||
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
result = Chats.delete_chats_by_user_id(user.id)
|
||||
return result
|
||||
|
||||
|
||||
############################
|
||||
# GetUserChatList
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_user_id(
|
||||
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
return Chats.get_chat_list_by_user_id(user_id, skip, limit)
|
||||
|
||||
|
||||
############################
|
||||
# GetArchivedChats
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/archived", response_model=List[ChatTitleIdResponse])
|
||||
async def get_archived_session_user_chat_list(
|
||||
user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
|
||||
|
||||
|
||||
############################
|
||||
# GetChats
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/all", response_model=List[ChatResponse])
|
||||
async def get_all_user_chats(user=Depends(get_current_user)):
|
||||
async def get_user_chats(user=Depends(get_current_user)):
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
for chat in Chats.get_all_chats_by_user_id(user.id)
|
||||
for chat in Chats.get_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
|
||||
|
@ -67,9 +113,14 @@ async def get_all_user_chats(user=Depends(get_current_user)):
|
|||
|
||||
@router.get("/all/db", response_model=List[ChatResponse])
|
||||
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||
if not ENABLE_ADMIN_EXPORT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
for chat in Chats.get_all_chats()
|
||||
for chat in Chats.get_chats()
|
||||
]
|
||||
|
||||
|
||||
|
@ -90,45 +141,6 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
|||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetAllTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/tags/all", response_model=List[TagModel])
|
||||
async def get_all_tags(user=Depends(get_current_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChatsByTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
|
||||
async def get_user_chats_by_tag_name(
|
||||
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
chat_ids = [
|
||||
chat_id_tag.chat_id
|
||||
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
|
||||
]
|
||||
|
||||
chats = Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
|
||||
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
|
||||
|
||||
return chats
|
||||
|
||||
|
||||
############################
|
||||
# GetChatById
|
||||
############################
|
||||
|
@ -176,10 +188,11 @@ async def update_chat_by_id(
|
|||
@router.delete("/{id}", response_model=bool)
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
|
||||
|
||||
if (
|
||||
user.role == "user"
|
||||
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
|
||||
):
|
||||
if user.role == "admin":
|
||||
result = Chats.delete_chat_by_id(id)
|
||||
return result
|
||||
else:
|
||||
if not request.app.state.USER_PERMISSIONS["chat"]["deletion"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
@ -189,6 +202,23 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
|
|||
return result
|
||||
|
||||
|
||||
############################
|
||||
# ArchiveChat
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ShareChatById
|
||||
############################
|
||||
|
@ -251,6 +281,14 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
|||
|
||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
|
||||
if user.role == "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role == "user":
|
||||
chat = Chats.get_chat_by_share_id(share_id)
|
||||
elif user.role == "admin":
|
||||
chat = Chats.get_chat_by_id(share_id)
|
||||
|
||||
if chat:
|
||||
|
@ -261,6 +299,45 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
|
|||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetAllTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/tags/all", response_model=List[TagModel])
|
||||
async def get_all_tags(user=Depends(get_current_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChatsByTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_tag_name(
|
||||
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
chat_ids = [
|
||||
chat_id_tag.chat_id
|
||||
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
|
||||
]
|
||||
|
||||
chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
|
||||
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
|
||||
|
||||
return chats
|
||||
|
||||
|
||||
############################
|
||||
# GetChatTagsById
|
||||
############################
|
||||
|
@ -341,24 +418,3 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
|||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteAllChats
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/", response_model=bool)
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
|
||||
|
||||
if (
|
||||
user.role == "user"
|
||||
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
result = Chats.delete_chats_by_user_id(user.id)
|
||||
return result
|
||||
|
|
|
@ -58,7 +58,7 @@ async def update_user_permissions(
|
|||
@router.post("/update/role", response_model=Optional[UserModel])
|
||||
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
|
||||
|
||||
if user.id != form_data.id:
|
||||
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
|
||||
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
||||
|
||||
raise HTTPException(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from fastapi import APIRouter, UploadFile, File, Response
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from peewee import SqliteDatabase
|
||||
from starlette.responses import StreamingResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -7,11 +8,11 @@ from pydantic import BaseModel
|
|||
from fpdf import FPDF
|
||||
import markdown
|
||||
|
||||
|
||||
from apps.web.internal.db import DB
|
||||
from utils.utils import get_admin_user
|
||||
from utils.misc import calculate_sha256, get_gravatar_url
|
||||
|
||||
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR
|
||||
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT
|
||||
from constants import ERROR_MESSAGES
|
||||
from typing import List
|
||||
|
||||
|
@ -91,9 +92,18 @@ async def download_chat_as_pdf(
|
|||
|
||||
@router.get("/db/download")
|
||||
async def download_db(user=Depends(get_admin_user)):
|
||||
|
||||
if not ENABLE_ADMIN_EXPORT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
if not isinstance(DB, SqliteDatabase):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
|
||||
)
|
||||
return FileResponse(
|
||||
f"{DATA_DIR}/webui.db",
|
||||
DB.database,
|
||||
media_type="application/octet-stream",
|
||||
filename="webui.db",
|
||||
)
|
||||
|
|
|
@ -18,6 +18,51 @@ from secrets import token_bytes
|
|||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
####################################
|
||||
# LOGGING
|
||||
####################################
|
||||
|
||||
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
||||
|
||||
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
||||
if GLOBAL_LOG_LEVEL in log_levels:
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
||||
else:
|
||||
GLOBAL_LOG_LEVEL = "INFO"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
||||
|
||||
log_sources = [
|
||||
"AUDIO",
|
||||
"COMFYUI",
|
||||
"CONFIG",
|
||||
"DB",
|
||||
"IMAGES",
|
||||
"LITELLM",
|
||||
"MAIN",
|
||||
"MODELS",
|
||||
"OLLAMA",
|
||||
"OPENAI",
|
||||
"RAG",
|
||||
"WEBHOOK",
|
||||
]
|
||||
|
||||
SRC_LOG_LEVELS = {}
|
||||
|
||||
for source in log_sources:
|
||||
log_env_var = source + "_LOG_LEVEL"
|
||||
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||
if SRC_LOG_LEVELS[source] not in log_levels:
|
||||
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
####################################
|
||||
# Load .env file
|
||||
####################################
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
|
@ -26,9 +71,10 @@ except ImportError:
|
|||
log.warning("dotenv not installed, skipping...")
|
||||
|
||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
WEBUI_NAME += " (Open WebUI)"
|
||||
|
||||
shutil.copyfile("../build/favicon.png", "./static/favicon.png")
|
||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||
|
||||
####################################
|
||||
# ENV (dev,test,prod)
|
||||
|
@ -103,47 +149,30 @@ for version in soup.find_all("h2"):
|
|||
|
||||
CHANGELOG = changelog_json
|
||||
|
||||
####################################
|
||||
# DATA/FRONTEND BUILD DIR
|
||||
####################################
|
||||
|
||||
DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve())
|
||||
FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
|
||||
|
||||
try:
|
||||
with open(f"{DATA_DIR}/config.json", "r") as f:
|
||||
CONFIG_DATA = json.load(f)
|
||||
except:
|
||||
CONFIG_DATA = {}
|
||||
|
||||
####################################
|
||||
# LOGGING
|
||||
# Static DIR
|
||||
####################################
|
||||
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
||||
|
||||
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
||||
if GLOBAL_LOG_LEVEL in log_levels:
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
||||
STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve())
|
||||
|
||||
frontend_favicon = f"{FRONTEND_BUILD_DIR}/favicon.png"
|
||||
if os.path.exists(frontend_favicon):
|
||||
shutil.copyfile(frontend_favicon, f"{STATIC_DIR}/favicon.png")
|
||||
else:
|
||||
GLOBAL_LOG_LEVEL = "INFO"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
||||
|
||||
log_sources = [
|
||||
"AUDIO",
|
||||
"COMFYUI",
|
||||
"CONFIG",
|
||||
"DB",
|
||||
"IMAGES",
|
||||
"LITELLM",
|
||||
"MAIN",
|
||||
"MODELS",
|
||||
"OLLAMA",
|
||||
"OPENAI",
|
||||
"RAG",
|
||||
"WEBHOOK",
|
||||
]
|
||||
|
||||
SRC_LOG_LEVELS = {}
|
||||
|
||||
for source in log_sources:
|
||||
log_env_var = source + "_LOG_LEVEL"
|
||||
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||
if SRC_LOG_LEVELS[source] not in log_levels:
|
||||
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
|
||||
|
||||
####################################
|
||||
# CUSTOM_NAME
|
||||
|
@ -165,7 +194,7 @@ if CUSTOM_NAME:
|
|||
|
||||
r = requests.get(url, stream=True)
|
||||
if r.status_code == 200:
|
||||
with open("./static/favicon.png", "wb") as f:
|
||||
with open(f"{STATIC_DIR}/favicon.png", "wb") as f:
|
||||
r.raw.decode_content = True
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
|
||||
|
@ -173,22 +202,7 @@ if CUSTOM_NAME:
|
|||
except Exception as e:
|
||||
log.exception(e)
|
||||
pass
|
||||
else:
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
WEBUI_NAME += " (Open WebUI)"
|
||||
|
||||
####################################
|
||||
# DATA/FRONTEND BUILD DIR
|
||||
####################################
|
||||
|
||||
DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve())
|
||||
FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
|
||||
|
||||
try:
|
||||
with open(f"{DATA_DIR}/config.json", "r") as f:
|
||||
CONFIG_DATA = json.load(f)
|
||||
except:
|
||||
CONFIG_DATA = {}
|
||||
|
||||
####################################
|
||||
# File Upload DIR
|
||||
|
@ -210,7 +224,7 @@ Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
|||
# Docs DIR
|
||||
####################################
|
||||
|
||||
DOCS_DIR = f"{DATA_DIR}/docs"
|
||||
DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
|
||||
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
|
@ -257,6 +271,7 @@ OLLAMA_API_BASE_URL = os.environ.get(
|
|||
|
||||
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
|
||||
K8S_FLAG = os.environ.get("K8S_FLAG", "")
|
||||
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
|
||||
|
||||
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
|
||||
OLLAMA_BASE_URL = (
|
||||
|
@ -266,9 +281,13 @@ if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
|
|||
)
|
||||
|
||||
if ENV == "prod":
|
||||
if OLLAMA_BASE_URL == "/ollama":
|
||||
if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
|
||||
if USE_OLLAMA_DOCKER.lower() == "true":
|
||||
# if you use all-in-one docker container (Open WebUI + Ollama)
|
||||
# with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
|
||||
OLLAMA_BASE_URL = "http://localhost:11434"
|
||||
else:
|
||||
OLLAMA_BASE_URL = "http://host.docker.internal:11434"
|
||||
|
||||
elif K8S_FLAG:
|
||||
OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"
|
||||
|
||||
|
@ -306,6 +325,18 @@ OPENAI_API_BASE_URLS = [
|
|||
for url in OPENAI_API_BASE_URLS.split(";")
|
||||
]
|
||||
|
||||
OPENAI_API_KEY = ""
|
||||
|
||||
try:
|
||||
OPENAI_API_KEY = OPENAI_API_KEYS[
|
||||
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
|
||||
]
|
||||
except:
|
||||
pass
|
||||
|
||||
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
|
||||
|
||||
|
||||
####################################
|
||||
# WEBUI
|
||||
####################################
|
||||
|
@ -336,6 +367,17 @@ DEFAULT_PROMPT_SUGGESTIONS = (
|
|||
"title": ["Show me a code snippet", "of a website's sticky header"],
|
||||
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
|
||||
},
|
||||
{
|
||||
"title": [
|
||||
"Explain options trading",
|
||||
"if I'm familiar with buying and selling stocks",
|
||||
],
|
||||
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
|
||||
},
|
||||
{
|
||||
"title": ["Overcome procrastination", "give me tips"],
|
||||
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -348,13 +390,14 @@ USER_PERMISSIONS_CHAT_DELETION = (
|
|||
|
||||
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
|
||||
|
||||
|
||||
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
|
||||
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
|
||||
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
|
||||
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
|
||||
|
||||
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
|
||||
|
||||
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
|
||||
|
||||
####################################
|
||||
# WEBUI_VERSION
|
||||
####################################
|
||||
|
@ -389,21 +432,87 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
|||
####################################
|
||||
|
||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
|
||||
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
|
||||
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
|
||||
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
|
||||
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
|
||||
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
|
||||
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
|
||||
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
|
||||
# Comma-separated list of header=value pairs
|
||||
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
|
||||
if CHROMA_HTTP_HEADERS:
|
||||
CHROMA_HTTP_HEADERS = dict(
|
||||
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
|
||||
)
|
||||
else:
|
||||
CHROMA_HTTP_HEADERS = None
|
||||
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
||||
|
||||
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
|
||||
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
|
||||
|
||||
ENABLE_RAG_HYBRID_SEARCH = (
|
||||
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
|
||||
)
|
||||
CHROMA_CLIENT = chromadb.PersistentClient(
|
||||
|
||||
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
|
||||
|
||||
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
|
||||
|
||||
RAG_EMBEDDING_MODEL = os.environ.get(
|
||||
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
|
||||
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
||||
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
||||
)
|
||||
|
||||
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
|
||||
if not RAG_RERANKING_MODEL == "":
|
||||
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
|
||||
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
||||
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
||||
|
||||
if USE_CUDA.lower() == "true":
|
||||
DEVICE_TYPE = "cuda"
|
||||
else:
|
||||
DEVICE_TYPE = "cpu"
|
||||
|
||||
if CHROMA_HTTP_HOST != "":
|
||||
CHROMA_CLIENT = chromadb.HttpClient(
|
||||
host=CHROMA_HTTP_HOST,
|
||||
port=CHROMA_HTTP_PORT,
|
||||
headers=CHROMA_HTTP_HEADERS,
|
||||
ssl=CHROMA_HTTP_SSL,
|
||||
tenant=CHROMA_TENANT,
|
||||
database=CHROMA_DATABASE,
|
||||
settings=Settings(allow_reset=True, anonymized_telemetry=False),
|
||||
)
|
||||
else:
|
||||
CHROMA_CLIENT = chromadb.PersistentClient(
|
||||
path=CHROMA_DATA_PATH,
|
||||
settings=Settings(allow_reset=True, anonymized_telemetry=False),
|
||||
)
|
||||
CHUNK_SIZE = 1500
|
||||
CHUNK_OVERLAP = 100
|
||||
tenant=CHROMA_TENANT,
|
||||
database=CHROMA_DATABASE,
|
||||
)
|
||||
|
||||
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
|
||||
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
|
||||
|
||||
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
<context>
|
||||
[context]
|
||||
</context>
|
||||
|
@ -417,17 +526,70 @@ And answer according to the language of the user's question.
|
|||
Given the context information, answer the query.
|
||||
Query: [query]"""
|
||||
|
||||
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
|
||||
|
||||
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
||||
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
|
||||
|
||||
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
|
||||
|
||||
####################################
|
||||
# Transcribe
|
||||
####################################
|
||||
|
||||
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
|
||||
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
|
||||
WHISPER_MODEL_AUTO_UPDATE = (
|
||||
os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# Images
|
||||
####################################
|
||||
|
||||
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
|
||||
|
||||
ENABLE_IMAGE_GENERATION = (
|
||||
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
|
||||
)
|
||||
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
|
||||
|
||||
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
|
||||
|
||||
IMAGES_OPENAI_API_BASE_URL = os.getenv(
|
||||
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
|
||||
)
|
||||
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
|
||||
|
||||
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
|
||||
|
||||
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
|
||||
|
||||
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
|
||||
|
||||
####################################
|
||||
# Audio
|
||||
####################################
|
||||
|
||||
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
|
||||
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
|
||||
|
||||
####################################
|
||||
# LiteLLM
|
||||
####################################
|
||||
|
||||
|
||||
ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true"
|
||||
|
||||
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
|
||||
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
|
||||
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
|
||||
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
|
||||
|
||||
|
||||
####################################
|
||||
# Database
|
||||
####################################
|
||||
|
||||
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
|
||||
|
|
|
@ -3,6 +3,10 @@ from enum import Enum
|
|||
|
||||
class MESSAGES(str, Enum):
|
||||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
||||
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
|
||||
MODEL_DELETED = (
|
||||
lambda model="": f"The model '{model}' has been deleted successfully."
|
||||
)
|
||||
|
||||
|
||||
class WEBHOOK_MESSAGES(str, Enum):
|
||||
|
@ -65,3 +69,9 @@ class ERROR_MESSAGES(str, Enum):
|
|||
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
||||
|
||||
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
||||
|
||||
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
|
||||
|
||||
INVALID_URL = (
|
||||
"Oops! The URL you provided is invalid. Please double-check and try again."
|
||||
)
|
||||
|
|
|
@ -18,6 +18,18 @@
|
|||
{
|
||||
"title": ["Show me a code snippet", "of a website's sticky header"],
|
||||
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
|
||||
},
|
||||
{
|
||||
"title": ["Explain options trading", "if I'm familiar with buying and selling stocks"],
|
||||
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks."
|
||||
},
|
||||
{
|
||||
"title": ["Overcome procrastination", "give me tips"],
|
||||
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?"
|
||||
},
|
||||
{
|
||||
"title": ["Grammar check", "rewrite it for better readability "],
|
||||
"content": "Check the following sentence for grammar and clarity: \"[sentence]\". Rewrite it for better readability while maintaining its original meaning."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
0
backend/dev.sh
Normal file → Executable file
0
backend/dev.sh
Normal file → Executable file
|
@ -5,6 +5,7 @@ import time
|
|||
import os
|
||||
import sys
|
||||
import logging
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status
|
||||
|
@ -18,12 +19,18 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||
|
||||
from apps.ollama.main import app as ollama_app
|
||||
from apps.openai.main import app as openai_app
|
||||
from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
|
||||
|
||||
from apps.litellm.main import (
|
||||
app as litellm_app,
|
||||
start_litellm_background,
|
||||
shutdown_litellm_background,
|
||||
)
|
||||
from apps.audio.main import app as audio_app
|
||||
from apps.images.main import app as images_app
|
||||
from apps.rag.main import app as rag_app
|
||||
from apps.web.main import app as webui_app
|
||||
|
||||
import asyncio
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
|
@ -38,11 +45,15 @@ from config import (
|
|||
VERSION,
|
||||
CHANGELOG,
|
||||
FRONTEND_BUILD_DIR,
|
||||
MODEL_FILTER_ENABLED,
|
||||
CACHE_DIR,
|
||||
STATIC_DIR,
|
||||
ENABLE_LITELLM,
|
||||
ENABLE_MODEL_FILTER,
|
||||
MODEL_FILTER_LIST,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
SRC_LOG_LEVELS,
|
||||
WEBHOOK_URL,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
@ -79,7 +90,7 @@ https://github.com/open-webui/open-webui
|
|||
|
||||
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
|
||||
|
||||
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
@ -106,11 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|||
if "docs" in data:
|
||||
data = {**data}
|
||||
data["messages"] = rag_messages(
|
||||
data["docs"],
|
||||
data["messages"],
|
||||
rag_app.state.RAG_TEMPLATE,
|
||||
rag_app.state.TOP_K,
|
||||
rag_app.state.sentence_transformer_ef,
|
||||
docs=data["docs"],
|
||||
messages=data["messages"],
|
||||
template=rag_app.state.RAG_TEMPLATE,
|
||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||
k=rag_app.state.TOP_K,
|
||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||
r=rag_app.state.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
del data["docs"]
|
||||
|
||||
|
@ -162,7 +176,8 @@ async def check_url(request: Request, call_next):
|
|||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
await litellm_app_startup()
|
||||
if ENABLE_LITELLM:
|
||||
asyncio.create_task(start_litellm_background())
|
||||
|
||||
|
||||
app.mount("/api/v1", webui_app)
|
||||
|
@ -194,13 +209,14 @@ async def get_app_config():
|
|||
"default_models": webui_app.state.DEFAULT_MODELS,
|
||||
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
|
||||
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
||||
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/config/model/filter")
|
||||
async def get_model_filter_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"enabled": app.state.MODEL_FILTER_ENABLED,
|
||||
"enabled": app.state.ENABLE_MODEL_FILTER,
|
||||
"models": app.state.MODEL_FILTER_LIST,
|
||||
}
|
||||
|
||||
|
@ -214,20 +230,20 @@ class ModelFilterConfigForm(BaseModel):
|
|||
async def update_model_filter_config(
|
||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.MODEL_FILTER_ENABLED = form_data.enabled
|
||||
app.state.ENABLE_MODEL_FILTER = form_data.enabled
|
||||
app.state.MODEL_FILTER_LIST = form_data.models
|
||||
|
||||
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
|
||||
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
||||
|
||||
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
|
||||
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
||||
|
||||
litellm_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
|
||||
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
||||
|
||||
return {
|
||||
"enabled": app.state.MODEL_FILTER_ENABLED,
|
||||
"enabled": app.state.ENABLE_MODEL_FILTER,
|
||||
"models": app.state.MODEL_FILTER_LIST,
|
||||
}
|
||||
|
||||
|
@ -269,14 +285,16 @@ async def get_app_changelog():
|
|||
@app.get("/api/version/updates")
|
||||
async def get_app_latest_release_version():
|
||||
try:
|
||||
response = requests.get(
|
||||
f"https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
latest_version = response.json()["tag_name"]
|
||||
data = await response.json()
|
||||
latest_version = data["tag_name"]
|
||||
|
||||
return {"current": VERSION, "latest": latest_version[1:]}
|
||||
except Exception as e:
|
||||
except aiohttp.ClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
|
||||
|
@ -293,16 +311,26 @@ async def get_manifest_json():
|
|||
"background_color": "#343541",
|
||||
"theme_color": "#343541",
|
||||
"orientation": "portrait-primary",
|
||||
"icons": [{"src": "/favicon.png", "type": "image/png", "sizes": "844x884"}],
|
||||
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
|
||||
}
|
||||
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
app.mount("/cache", StaticFiles(directory="data/cache"), name="cache")
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
|
||||
|
||||
|
||||
app.mount(
|
||||
if os.path.exists(FRONTEND_BUILD_DIR):
|
||||
app.mount(
|
||||
"/",
|
||||
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
|
||||
name="spa-static-files",
|
||||
)
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
if ENABLE_LITELLM:
|
||||
await shutdown_litellm_background()
|
||||
|
|
|
@ -1,53 +1,62 @@
|
|||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic
|
||||
python-multipart
|
||||
fastapi==0.109.2
|
||||
uvicorn[standard]==0.22.0
|
||||
pydantic==2.7.1
|
||||
python-multipart==0.0.9
|
||||
|
||||
flask
|
||||
flask_cors
|
||||
Flask==3.0.3
|
||||
Flask-Cors==4.0.0
|
||||
|
||||
python-socketio
|
||||
python-jose
|
||||
passlib[bcrypt]
|
||||
uuid
|
||||
python-socketio==5.11.2
|
||||
python-jose==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
uuid==1.30
|
||||
|
||||
requests
|
||||
aiohttp
|
||||
peewee
|
||||
peewee-migrate
|
||||
bcrypt
|
||||
requests==2.31.0
|
||||
aiohttp==3.9.5
|
||||
peewee==3.17.3
|
||||
peewee-migrate==1.12.2
|
||||
psycopg2-binary==2.9.9
|
||||
PyMySQL==1.1.0
|
||||
bcrypt==4.1.2
|
||||
|
||||
litellm==1.30.7
|
||||
boto3
|
||||
litellm==1.35.28
|
||||
litellm[proxy]==1.35.28
|
||||
|
||||
argon2-cffi
|
||||
apscheduler
|
||||
google-generativeai
|
||||
boto3==1.34.95
|
||||
|
||||
langchain
|
||||
langchain-community
|
||||
fake_useragent
|
||||
chromadb
|
||||
sentence_transformers
|
||||
pypdf
|
||||
docx2txt
|
||||
unstructured
|
||||
markdown
|
||||
pypandoc
|
||||
pandas
|
||||
openpyxl
|
||||
pyxlsb
|
||||
xlrd
|
||||
argon2-cffi==23.1.0
|
||||
APScheduler==3.10.4
|
||||
google-generativeai==0.5.2
|
||||
|
||||
opencv-python-headless
|
||||
rapidocr-onnxruntime
|
||||
langchain==0.1.16
|
||||
langchain-community==0.0.34
|
||||
langchain-chroma==0.1.0
|
||||
|
||||
fpdf2
|
||||
fake-useragent==1.5.1
|
||||
chromadb==0.4.24
|
||||
sentence-transformers==2.7.0
|
||||
pypdf==4.2.0
|
||||
docx2txt==0.8
|
||||
unstructured==0.11.8
|
||||
Markdown==3.6
|
||||
pypandoc==1.13
|
||||
pandas==2.2.2
|
||||
openpyxl==3.1.2
|
||||
pyxlsb==1.0.10
|
||||
xlrd==2.0.1
|
||||
validators==0.28.1
|
||||
|
||||
faster-whisper
|
||||
opencv-python-headless==4.9.0.80
|
||||
rapidocr-onnxruntime==1.2.3
|
||||
|
||||
PyJWT
|
||||
pyjwt[crypto]
|
||||
fpdf2==2.7.8
|
||||
rank-bm25==0.2.2
|
||||
|
||||
black
|
||||
langfuse
|
||||
faster-whisper==1.0.1
|
||||
|
||||
PyJWT==2.8.0
|
||||
PyJWT[crypto]==2.8.0
|
||||
|
||||
black==24.4.2
|
||||
langfuse==2.27.3
|
||||
youtube-transcript-api
|
||||
|
|
|
@ -6,17 +6,28 @@ cd "$SCRIPT_DIR" || exit
|
|||
KEY_FILE=.webui_secret_key
|
||||
|
||||
PORT="${PORT:-8080}"
|
||||
HOST="${HOST:-0.0.0.0}"
|
||||
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
|
||||
echo No WEBUI_SECRET_KEY provided
|
||||
echo "No WEBUI_SECRET_KEY provided"
|
||||
|
||||
if ! [ -e "$KEY_FILE" ]; then
|
||||
echo Generating WEBUI_SECRET_KEY
|
||||
echo "Generating WEBUI_SECRET_KEY"
|
||||
# Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one.
|
||||
echo $(head -c 12 /dev/random | base64) > $KEY_FILE
|
||||
echo $(head -c 12 /dev/random | base64) > "$KEY_FILE"
|
||||
fi
|
||||
|
||||
echo Loading WEBUI_SECRET_KEY from $KEY_FILE
|
||||
WEBUI_SECRET_KEY=`cat $KEY_FILE`
|
||||
echo "Loading WEBUI_SECRET_KEY from $KEY_FILE"
|
||||
WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
|
||||
fi
|
||||
|
||||
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'
|
||||
if [ "$USE_OLLAMA_DOCKER" = "true" ]; then
|
||||
echo "USE_OLLAMA is set to true, starting ollama serve."
|
||||
ollama serve &
|
||||
fi
|
||||
|
||||
if [ "$USE_CUDA_DOCKER" = "true" ]; then
|
||||
echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
|
||||
fi
|
||||
|
||||
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*'
|
||||
|
|
|
@ -7,7 +7,7 @@ SET "SCRIPT_DIR=%~dp0"
|
|||
cd /d "%SCRIPT_DIR%" || exit /b
|
||||
|
||||
SET "KEY_FILE=.webui_secret_key"
|
||||
SET "PORT=%PORT:8080%"
|
||||
IF "%PORT%"=="" SET PORT=8080
|
||||
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
|
||||
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 6 KiB After Width: | Height: | Size: 11 KiB |
BIN
backend/static/logo.png
Normal file
BIN
backend/static/logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6 KiB |
1
backend/static/user-import.csv
Normal file
1
backend/static/user-import.csv
Normal file
|
@ -0,0 +1 @@
|
|||
Name,Email,Password,Role
|
|
BIN
backend/utils/logo.png
Normal file
BIN
backend/utils/logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6 KiB |
|
@ -89,6 +89,8 @@ def get_current_user(
|
|||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
return user
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
@ -99,11 +101,15 @@ def get_current_user(
|
|||
|
||||
def get_current_user_by_api_key(api_key: str):
|
||||
user = Users.get_user_by_api_key(api_key)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,12 @@
|
|||
echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]"
|
||||
read ans
|
||||
if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then
|
||||
command docker-compose 2>/dev/null
|
||||
if [ "$?" == "0" ]; then
|
||||
docker-compose down -v
|
||||
else
|
||||
docker compose down -v
|
||||
fi
|
||||
else
|
||||
echo "Operation cancelled."
|
||||
fi
|
||||
|
|
8
cypress.config.ts
Normal file
8
cypress.config.ts
Normal file
|
@ -0,0 +1,8 @@
|
|||
import { defineConfig } from 'cypress';
|
||||
|
||||
export default defineConfig({
|
||||
e2e: {
|
||||
baseUrl: 'http://localhost:8080'
|
||||
},
|
||||
video: true
|
||||
});
|
46
cypress/e2e/chat.cy.ts
Normal file
46
cypress/e2e/chat.cy.ts
Normal file
|
@ -0,0 +1,46 @@
|
|||
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
|
||||
/// <reference path="../support/index.d.ts" />
|
||||
|
||||
// These tests run through the chat flow.
|
||||
describe('Settings', () => {
|
||||
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
|
||||
after(() => {
|
||||
// eslint-disable-next-line cypress/no-unnecessary-waiting
|
||||
cy.wait(2000);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Login as the admin user
|
||||
cy.loginAdmin();
|
||||
// Visit the home page
|
||||
cy.visit('/');
|
||||
});
|
||||
|
||||
context('Ollama', () => {
|
||||
it('user can select a model', () => {
|
||||
// Click on the model selector
|
||||
cy.get('button[aria-label="Select a model"]').click();
|
||||
// Select the first model
|
||||
cy.get('button[aria-label="model-item"]').first().click();
|
||||
});
|
||||
|
||||
it('user can perform text chat', () => {
|
||||
// Click on the model selector
|
||||
cy.get('button[aria-label="Select a model"]').click();
|
||||
// Select the first model
|
||||
cy.get('button[aria-label="model-item"]').first().click();
|
||||
// Type a message
|
||||
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
|
||||
force: true
|
||||
});
|
||||
// Send the message
|
||||
cy.get('button[type="submit"]').click();
|
||||
// User's message should be visible
|
||||
cy.get('.chat-user').should('exist');
|
||||
// Wait for the response
|
||||
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
|
||||
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
|
||||
.should('exist');
|
||||
});
|
||||
});
|
||||
});
|
52
cypress/e2e/registration.cy.ts
Normal file
52
cypress/e2e/registration.cy.ts
Normal file
|
@ -0,0 +1,52 @@
|
|||
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
|
||||
/// <reference path="../support/index.d.ts" />
|
||||
import { adminUser } from '../support/e2e';
|
||||
|
||||
// These tests assume the following defaults:
|
||||
// 1. No users exist in the database or that the test admin user is an admin
|
||||
// 2. Language is set to English
|
||||
// 3. The default role for new users is 'pending'
|
||||
describe('Registration and Login', () => {
|
||||
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
|
||||
after(() => {
|
||||
// eslint-disable-next-line cypress/no-unnecessary-waiting
|
||||
cy.wait(2000);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
cy.visit('/');
|
||||
});
|
||||
|
||||
it('should register a new user as pending', () => {
|
||||
const userName = `Test User - ${Date.now()}`;
|
||||
const userEmail = `cypress-${Date.now()}@example.com`;
|
||||
// Toggle from sign in to sign up
|
||||
cy.contains('Sign up').click();
|
||||
// Fill out the form
|
||||
cy.get('input[autocomplete="name"]').type(userName);
|
||||
cy.get('input[autocomplete="email"]').type(userEmail);
|
||||
cy.get('input[type="password"]').type('password');
|
||||
// Submit the form
|
||||
cy.get('button[type="submit"]').click();
|
||||
// Wait until the user is redirected to the home page
|
||||
cy.contains(userName);
|
||||
// Expect the user to be pending
|
||||
cy.contains('Check Again');
|
||||
});
|
||||
|
||||
it('can login with the admin user', () => {
|
||||
// Fill out the form
|
||||
cy.get('input[autocomplete="email"]').type(adminUser.email);
|
||||
cy.get('input[type="password"]').type(adminUser.password);
|
||||
// Submit the form
|
||||
cy.get('button[type="submit"]').click();
|
||||
// Wait until the user is redirected to the home page
|
||||
cy.contains(adminUser.name);
|
||||
// Dismiss the changelog dialog if it is visible
|
||||
cy.getAllLocalStorage().then((ls) => {
|
||||
if (!ls['version']) {
|
||||
cy.get('button').contains("Okay, Let's Go!").click();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
88
cypress/e2e/settings.cy.ts
Normal file
88
cypress/e2e/settings.cy.ts
Normal file
|
@ -0,0 +1,88 @@
|
|||
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
|
||||
/// <reference path="../support/index.d.ts" />
|
||||
import { adminUser } from '../support/e2e';
|
||||
|
||||
// These tests run through the various settings pages, ensuring that the user can interact with them as expected
|
||||
describe('Settings', () => {
|
||||
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
|
||||
after(() => {
|
||||
// eslint-disable-next-line cypress/no-unnecessary-waiting
|
||||
cy.wait(2000);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Login as the admin user
|
||||
cy.loginAdmin();
|
||||
// Visit the home page
|
||||
cy.visit('/');
|
||||
// Open the sidebar if it is not already open
|
||||
cy.get('[aria-label="Open sidebar"]').then(() => {
|
||||
cy.get('button[id="sidebar-toggle-button"]').click();
|
||||
});
|
||||
// Click on the profile link
|
||||
cy.get('button').contains(adminUser.name).click();
|
||||
// Click on the settings link
|
||||
cy.get('button').contains('Settings').click();
|
||||
});
|
||||
|
||||
context('General', () => {
|
||||
it('user can open the General modal and hit save', () => {
|
||||
cy.get('button').contains('General').click();
|
||||
cy.get('button').contains('Save').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Connections', () => {
|
||||
it('user can open the Connections modal and hit save', () => {
|
||||
cy.get('button').contains('Connections').click();
|
||||
cy.get('button').contains('Save').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Models', () => {
|
||||
it('user can open the Models modal', () => {
|
||||
cy.get('button').contains('Models').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Interface', () => {
|
||||
it('user can open the Interface modal and hit save', () => {
|
||||
cy.get('button').contains('Interface').click();
|
||||
cy.get('button').contains('Save').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Audio', () => {
|
||||
it('user can open the Audio modal and hit save', () => {
|
||||
cy.get('button').contains('Audio').click();
|
||||
cy.get('button').contains('Save').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Images', () => {
|
||||
it('user can open the Images modal and hit save', () => {
|
||||
cy.get('button').contains('Images').click();
|
||||
// Currently fails because the backend requires a valid URL
|
||||
// cy.get('button').contains('Save').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Chats', () => {
|
||||
it('user can open the Chats modal', () => {
|
||||
cy.get('button').contains('Chats').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('Account', () => {
|
||||
it('user can open the Account modal and hit save', () => {
|
||||
cy.get('button').contains('Account').click();
|
||||
cy.get('button').contains('Save').click();
|
||||
});
|
||||
});
|
||||
|
||||
context('About', () => {
|
||||
it('user can open the About modal', () => {
|
||||
cy.get('button').contains('About').click();
|
||||
});
|
||||
});
|
||||
});
|
73
cypress/support/e2e.ts
Normal file
73
cypress/support/e2e.ts
Normal file
|
@ -0,0 +1,73 @@
|
|||
/// <reference types="cypress" />
|
||||
|
||||
export const adminUser = {
|
||||
name: 'Admin User',
|
||||
email: 'admin@example.com',
|
||||
password: 'password'
|
||||
};
|
||||
|
||||
const login = (email: string, password: string) => {
|
||||
return cy.session(
|
||||
email,
|
||||
() => {
|
||||
// Visit auth page
|
||||
cy.visit('/auth');
|
||||
// Fill out the form
|
||||
cy.get('input[autocomplete="email"]').type(email);
|
||||
cy.get('input[type="password"]').type(password);
|
||||
// Submit the form
|
||||
cy.get('button[type="submit"]').click();
|
||||
// Wait until the user is redirected to the home page
|
||||
cy.get('#chat-search').should('exist');
|
||||
// Get the current version to skip the changelog dialog
|
||||
if (localStorage.getItem('version') === null) {
|
||||
cy.get('button').contains("Okay, Let's Go!").click();
|
||||
}
|
||||
},
|
||||
{
|
||||
validate: () => {
|
||||
cy.request({
|
||||
method: 'GET',
|
||||
url: '/api/v1/auths/',
|
||||
headers: {
|
||||
Authorization: 'Bearer ' + localStorage.getItem('token')
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
const register = (name: string, email: string, password: string) => {
|
||||
return cy
|
||||
.request({
|
||||
method: 'POST',
|
||||
url: '/api/v1/auths/signup',
|
||||
body: {
|
||||
name: name,
|
||||
email: email,
|
||||
password: password
|
||||
},
|
||||
failOnStatusCode: false
|
||||
})
|
||||
.then((response) => {
|
||||
expect(response.status).to.be.oneOf([200, 400]);
|
||||
});
|
||||
};
|
||||
|
||||
const registerAdmin = () => {
|
||||
return register(adminUser.name, adminUser.email, adminUser.password);
|
||||
};
|
||||
|
||||
const loginAdmin = () => {
|
||||
return login(adminUser.email, adminUser.password);
|
||||
};
|
||||
|
||||
Cypress.Commands.add('login', (email, password) => login(email, password));
|
||||
Cypress.Commands.add('register', (name, email, password) => register(name, email, password));
|
||||
Cypress.Commands.add('registerAdmin', () => registerAdmin());
|
||||
Cypress.Commands.add('loginAdmin', () => loginAdmin());
|
||||
|
||||
before(() => {
|
||||
cy.registerAdmin();
|
||||
});
|
11
cypress/support/index.d.ts
vendored
Normal file
11
cypress/support/index.d.ts
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
// load the global Cypress types
|
||||
/// <reference types="cypress" />
|
||||
|
||||
declare namespace Cypress {
|
||||
interface Chainable {
|
||||
login(email: string, password: string): Chainable<Element>;
|
||||
register(name: string, email: string, password: string): Chainable<Element>;
|
||||
registerAdmin(): Chainable<Element>;
|
||||
loginAdmin(): Chainable<Element>;
|
||||
}
|
||||
}
|
7
cypress/tsconfig.json
Normal file
7
cypress/tsconfig.json
Normal file
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"extends": "../tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"inlineSourceMap": true,
|
||||
"sourceMap": false
|
||||
}
|
||||
}
|
8
docker-compose.amdgpu.yaml
Normal file
8
docker-compose.amdgpu.yaml
Normal file
|
@ -0,0 +1,8 @@
|
|||
services:
|
||||
ollama:
|
||||
devices:
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
image: ollama/ollama:${OLLAMA_DOCKER_TAG-rocm}
|
||||
environment:
|
||||
- 'HSA_OVERRIDE_GFX_VERSION=${HSA_OVERRIDE_GFX_VERSION-11.0.0}'
|
|
@ -8,7 +8,7 @@ services:
|
|||
pull_policy: always
|
||||
tty: true
|
||||
restart: unless-stopped
|
||||
image: ollama/ollama:latest
|
||||
image: ollama/ollama:${OLLAMA_DOCKER_TAG-latest}
|
||||
|
||||
open-webui:
|
||||
build:
|
||||
|
@ -16,7 +16,7 @@ services:
|
|||
args:
|
||||
OLLAMA_BASE_URL: '/ollama'
|
||||
dockerfile: Dockerfile
|
||||
image: ghcr.io/open-webui/open-webui:main
|
||||
image: ghcr.io/open-webui/open-webui:${WEBUI_DOCKER_TAG-main}
|
||||
container_name: open-webui
|
||||
volumes:
|
||||
- open-webui:/app/backend/data
|
||||
|
|
|
@ -7,7 +7,11 @@ ollama
|
|||
{{- end -}}
|
||||
|
||||
{{- define "ollama.url" -}}
|
||||
{{- printf "http://%s.%s.svc.cluster.local:%d/" (include "ollama.name" .) (.Release.Namespace) (.Values.ollama.service.port | int) }}
|
||||
{{- if .Values.ollama.externalHost }}
|
||||
{{- printf .Values.ollama.externalHost }}
|
||||
{{- else }}
|
||||
{{- printf "http://%s.%s.svc.cluster.local:%d" (include "ollama.name" .) (.Release.Namespace) (.Values.ollama.service.port | int) }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{- define "chart.name" -}}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
{{- if not .Values.ollama.externalHost }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
|
@ -19,3 +20,4 @@ spec:
|
|||
port: {{ .port }}
|
||||
targetPort: http
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
{{- if not .Values.ollama.externalHost }}
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
metadata:
|
||||
|
@ -94,3 +95,4 @@ spec:
|
|||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
|
|
@ -17,7 +17,9 @@ spec:
|
|||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.webui.persistence.size }}
|
||||
{{- if .Values.webui.persistence.storageClass }}
|
||||
storageClassName: {{ .Values.webui.persistence.storageClass }}
|
||||
{{- end }}
|
||||
{{- with .Values.webui.persistence.selector }}
|
||||
selector:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
nameOverride: ""
|
||||
|
||||
ollama:
|
||||
externalHost: ""
|
||||
annotations: {}
|
||||
podAnnotations: {}
|
||||
replicaCount: 1
|
||||
|
|
1782
package-lock.json
generated
1782
package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "open-webui",
|
||||
"version": "0.1.117",
|
||||
"version": "0.1.123",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "vite dev --host",
|
||||
|
@ -14,7 +14,8 @@
|
|||
"lint:backend": "pylint backend/",
|
||||
"format": "prettier --plugin-search-dir --write '**/*.{js,ts,svelte,css,md,html,json}'",
|
||||
"format:backend": "black . --exclude \"/venv/\"",
|
||||
"i18n:parse": "i18next --config i18next-parser.config.ts && prettier --write 'src/lib/i18n/**/*.{js,json}'"
|
||||
"i18n:parse": "i18next --config i18next-parser.config.ts && prettier --write 'src/lib/i18n/**/*.{js,json}'",
|
||||
"cy:open": "cypress open"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@sveltejs/adapter-auto": "^2.0.0",
|
||||
|
@ -25,8 +26,10 @@
|
|||
"@typescript-eslint/eslint-plugin": "^6.17.0",
|
||||
"@typescript-eslint/parser": "^6.17.0",
|
||||
"autoprefixer": "^10.4.16",
|
||||
"cypress": "^13.8.1",
|
||||
"eslint": "^8.56.0",
|
||||
"eslint-config-prettier": "^8.5.0",
|
||||
"eslint-plugin-cypress": "^3.0.2",
|
||||
"eslint-plugin-svelte": "^2.30.0",
|
||||
"i18next-parser": "^8.13.0",
|
||||
"postcss": "^8.4.31",
|
||||
|
@ -46,6 +49,7 @@
|
|||
"async": "^3.2.5",
|
||||
"bits-ui": "^0.19.7",
|
||||
"dayjs": "^1.11.10",
|
||||
"eventsource-parser": "^1.1.2",
|
||||
"file-saver": "^2.0.5",
|
||||
"highlight.js": "^11.9.0",
|
||||
"i18next": "^23.10.0",
|
||||
|
@ -53,7 +57,6 @@
|
|||
"i18next-resources-to-backend": "^1.2.0",
|
||||
"idb": "^7.1.1",
|
||||
"js-sha256": "^0.10.1",
|
||||
"jspdf": "^2.5.1",
|
||||
"katex": "^0.16.9",
|
||||
"marked": "^9.1.0",
|
||||
"svelte-sonner": "^0.3.19",
|
||||
|
|
|
@ -82,6 +82,7 @@ usage() {
|
|||
echo "Examples:"
|
||||
echo " $0 --drop"
|
||||
echo " $0 --enable-gpu[count=1]"
|
||||
echo " $0 --enable-gpu[count=all]"
|
||||
echo " $0 --enable-api[port=11435]"
|
||||
echo " $0 --enable-gpu[count=1] --enable-api[port=12345] --webui[port=3000]"
|
||||
echo " $0 --enable-gpu[count=1] --enable-api[port=12345] --webui[port=3000] --data[folder=./ollama-data]"
|
||||
|
@ -160,7 +161,7 @@ else
|
|||
if [[ $enable_gpu == true ]]; then
|
||||
# Validate and process command-line arguments
|
||||
if [[ -n $gpu_count ]]; then
|
||||
if ! [[ $gpu_count =~ ^[0-9]+$ ]]; then
|
||||
if ! [[ $gpu_count =~ ^([0-9]+|all)$ ]]; then
|
||||
echo "Invalid GPU count: $gpu_count"
|
||||
exit 1
|
||||
fi
|
||||
|
|
39
src/app.html
39
src/app.html
|
@ -3,7 +3,7 @@
|
|||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%sveltekit.assets%/favicon.png" />
|
||||
<link rel="manifest" href="%sveltekit.assets%/manifest.json" />
|
||||
<link rel="manifest" href="%sveltekit.assets%/manifest.json" crossorigin="use-credentials" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1" />
|
||||
<meta name="robots" content="noindex,nofollow" />
|
||||
<script>
|
||||
|
@ -43,9 +43,46 @@
|
|||
})();
|
||||
</script>
|
||||
|
||||
<title>Open WebUI</title>
|
||||
|
||||
%sveltekit.head%
|
||||
</head>
|
||||
<body data-sveltekit-preload-data="hover">
|
||||
<div style="display: contents">%sveltekit.body%</div>
|
||||
|
||||
<div
|
||||
id="splash-screen"
|
||||
style="
|
||||
position: fixed;
|
||||
z-index: 100;
|
||||
background: #fff;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
"
|
||||
>
|
||||
<style type="text/css" nonce="">
|
||||
html {
|
||||
overflow-y: scroll !important;
|
||||
}
|
||||
</style>
|
||||
|
||||
<img
|
||||
style="
|
||||
position: absolute;
|
||||
width: 6rem;
|
||||
height: 6rem;
|
||||
top: 46%;
|
||||
left: 50%;
|
||||
margin: -40px 0 0 -40px;
|
||||
"
|
||||
src="/logo.svg"
|
||||
/>
|
||||
|
||||
<!-- <span style="position: absolute; bottom: 32px; left: 50%; margin: -36px 0 0 -36px">
|
||||
Footer content
|
||||
</span> -->
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
|
|
@ -1,11 +1,73 @@
|
|||
import { AUDIO_API_BASE_URL } from '$lib/constants';
|
||||
|
||||
export const getAudioConfig = async (token: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${AUDIO_API_BASE_URL}/config`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
type OpenAIConfigForm = {
|
||||
url: string;
|
||||
key: string;
|
||||
};
|
||||
|
||||
export const updateAudioConfig = async (token: string, payload: OpenAIConfigForm) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${AUDIO_API_BASE_URL}/config/update`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...payload
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const transcribeAudio = async (token: string, file: File) => {
|
||||
const data = new FormData();
|
||||
data.append('file', file);
|
||||
|
||||
let error = null;
|
||||
const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, {
|
||||
const res = await fetch(`${AUDIO_API_BASE_URL}/transcriptions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
|
@ -29,3 +91,40 @@ export const transcribeAudio = async (token: string, file: File) => {
|
|||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const synthesizeOpenAISpeech = async (
|
||||
token: string = '',
|
||||
speaker: string = 'alloy',
|
||||
text: string = ''
|
||||
) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${AUDIO_API_BASE_URL}/speech`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: 'tts-1',
|
||||
input: text,
|
||||
voice: speaker
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res;
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err.detail;
|
||||
console.log(err);
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
|
|
@ -58,7 +58,12 @@ export const userSignIn = async (email: string, password: string) => {
|
|||
return res;
|
||||
};
|
||||
|
||||
export const userSignUp = async (name: string, email: string, password: string) => {
|
||||
export const userSignUp = async (
|
||||
name: string,
|
||||
email: string,
|
||||
password: string,
|
||||
profile_image_url: string
|
||||
) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/signup`, {
|
||||
|
@ -69,7 +74,47 @@ export const userSignUp = async (name: string, email: string, password: string)
|
|||
body: JSON.stringify({
|
||||
name: name,
|
||||
email: email,
|
||||
password: password
|
||||
password: password,
|
||||
profile_image_url: profile_image_url
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const addUser = async (
|
||||
token: string,
|
||||
name: string,
|
||||
email: string,
|
||||
password: string,
|
||||
role: string = 'pending'
|
||||
) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/add`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(token && { authorization: `Bearer ${token}` })
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: name,
|
||||
email: email,
|
||||
password: password,
|
||||
role: role
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
|
|
|
@ -62,6 +62,68 @@ export const getChatList = async (token: string = '') => {
|
|||
return res;
|
||||
};
|
||||
|
||||
export const getChatListByUserId = async (token: string = '', userId: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/list/user/${userId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
...(token && { authorization: `Bearer ${token}` })
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.then((json) => {
|
||||
return json;
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err;
|
||||
console.log(err);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getArchivedChatList = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/archived`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
...(token && { authorization: `Bearer ${token}` })
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.then((json) => {
|
||||
return json;
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err;
|
||||
console.log(err);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getAllChats = async (token: string) => {
|
||||
let error = null;
|
||||
|
||||
|
@ -282,6 +344,38 @@ export const shareChatById = async (token: string, id: string) => {
|
|||
return res;
|
||||
};
|
||||
|
||||
export const archiveChatById = async (token: string, id: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
...(token && { authorization: `Bearer ${token}` })
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.then((json) => {
|
||||
return json;
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err;
|
||||
|
||||
console.log(err);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const deleteSharedChatById = async (token: string, id: string) => {
|
||||
let error = null;
|
||||
|
||||
|
|
|
@ -72,10 +72,10 @@ export const updateImageGenerationConfig = async (
|
|||
return res;
|
||||
};
|
||||
|
||||
export const getOpenAIKey = async (token: string = '') => {
|
||||
export const getOpenAIConfig = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${IMAGES_API_BASE_URL}/key`, {
|
||||
const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
|
@ -101,13 +101,13 @@ export const getOpenAIKey = async (token: string = '') => {
|
|||
throw error;
|
||||
}
|
||||
|
||||
return res.OPENAI_API_KEY;
|
||||
return res;
|
||||
};
|
||||
|
||||
export const updateOpenAIKey = async (token: string = '', key: string) => {
|
||||
export const updateOpenAIConfig = async (token: string = '', url: string, key: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, {
|
||||
const res = await fetch(`${IMAGES_API_BASE_URL}/openai/config/update`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
|
@ -115,6 +115,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
|
|||
...(token && { authorization: `Bearer ${token}` })
|
||||
},
|
||||
body: JSON.stringify({
|
||||
url: url,
|
||||
key: key
|
||||
})
|
||||
})
|
||||
|
@ -136,7 +137,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
|
|||
throw error;
|
||||
}
|
||||
|
||||
return res.OPENAI_API_KEY;
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getImageGenerationEngineUrls = async (token: string = '') => {
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import { OLLAMA_API_BASE_URL } from '$lib/constants';
|
||||
import { promptTemplate } from '$lib/utils';
|
||||
|
||||
export const getOllamaUrls = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
@ -144,7 +145,7 @@ export const generateTitle = async (
|
|||
) => {
|
||||
let error = null;
|
||||
|
||||
template = template.replace(/{{prompt}}/g, prompt);
|
||||
template = promptTemplate(template, prompt);
|
||||
|
||||
console.log(template);
|
||||
|
||||
|
@ -219,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa
|
|||
return res;
|
||||
};
|
||||
|
||||
export const generateEmbeddings = async (token: string = '', model: string, text: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: model,
|
||||
prompt: text
|
||||
})
|
||||
}).catch((err) => {
|
||||
error = err;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const generateTextCompletion = async (token: string = '', model: string, text: string) => {
|
||||
let error = null;
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import { OPENAI_API_BASE_URL } from '$lib/constants';
|
||||
import { promptTemplate } from '$lib/utils';
|
||||
|
||||
export const getOpenAIUrls = async (token: string = '') => {
|
||||
let error = null;
|
||||
|
@ -210,10 +211,12 @@ export const generateOpenAIChatCompletion = async (
|
|||
token: string = '',
|
||||
body: object,
|
||||
url: string = OPENAI_API_BASE_URL
|
||||
) => {
|
||||
): Promise<[Response | null, AbortController]> => {
|
||||
const controller = new AbortController();
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${url}/chat/completions`, {
|
||||
signal: controller.signal,
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
|
@ -230,7 +233,7 @@ export const generateOpenAIChatCompletion = async (
|
|||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
return [res, controller];
|
||||
};
|
||||
|
||||
export const synthesizeOpenAISpeech = async (
|
||||
|
@ -273,7 +276,7 @@ export const generateTitle = async (
|
|||
) => {
|
||||
let error = null;
|
||||
|
||||
template = template.replace(/{{prompt}}/g, prompt);
|
||||
template = promptTemplate(template, prompt);
|
||||
|
||||
console.log(template);
|
||||
|
||||
|
|
|
@ -123,6 +123,7 @@ export const getQuerySettings = async (token: string) => {
|
|||
|
||||
type QuerySettings = {
|
||||
k: number | null;
|
||||
r: number | null;
|
||||
template: string | null;
|
||||
};
|
||||
|
||||
|
@ -220,6 +221,37 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
|
|||
return res;
|
||||
};
|
||||
|
||||
export const uploadYoutubeTranscriptionToVectorDB = async (token: string, url: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${RAG_API_BASE_URL}/youtube`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
url: url
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err.detail;
|
||||
console.log(err);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const queryDoc = async (
|
||||
token: string,
|
||||
collection_name: string,
|
||||
|
@ -345,3 +377,132 @@ export const resetVectorDB = async (token: string) => {
|
|||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getEmbeddingConfig = async (token: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
type OpenAIConfigForm = {
|
||||
key: string;
|
||||
url: string;
|
||||
};
|
||||
|
||||
type EmbeddingModelUpdateForm = {
|
||||
openai_config?: OpenAIConfigForm;
|
||||
embedding_engine: string;
|
||||
embedding_model: string;
|
||||
};
|
||||
|
||||
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...payload
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getRerankingConfig = async (token: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${RAG_API_BASE_URL}/reranking`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
type RerankingModelUpdateForm = {
|
||||
reranking_model: string;
|
||||
};
|
||||
|
||||
export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...payload
|
||||
})
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
|
84
src/lib/apis/streaming/index.ts
Normal file
84
src/lib/apis/streaming/index.ts
Normal file
|
@ -0,0 +1,84 @@
|
|||
import { EventSourceParserStream } from 'eventsource-parser/stream';
|
||||
import type { ParsedEvent } from 'eventsource-parser';
|
||||
|
||||
type TextStreamUpdate = {
|
||||
done: boolean;
|
||||
value: string;
|
||||
};
|
||||
|
||||
// createOpenAITextStream takes a responseBody with a SSE response,
|
||||
// and returns an async generator that emits delta updates with large deltas chunked into random sized chunks
|
||||
export async function createOpenAITextStream(
|
||||
responseBody: ReadableStream<Uint8Array>,
|
||||
splitLargeDeltas: boolean
|
||||
): Promise<AsyncGenerator<TextStreamUpdate>> {
|
||||
const eventStream = responseBody
|
||||
.pipeThrough(new TextDecoderStream())
|
||||
.pipeThrough(new EventSourceParserStream())
|
||||
.getReader();
|
||||
let iterator = openAIStreamToIterator(eventStream);
|
||||
if (splitLargeDeltas) {
|
||||
iterator = streamLargeDeltasAsRandomChunks(iterator);
|
||||
}
|
||||
return iterator;
|
||||
}
|
||||
|
||||
async function* openAIStreamToIterator(
|
||||
reader: ReadableStreamDefaultReader<ParsedEvent>
|
||||
): AsyncGenerator<TextStreamUpdate> {
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
yield { done: true, value: '' };
|
||||
break;
|
||||
}
|
||||
if (!value) {
|
||||
continue;
|
||||
}
|
||||
const data = value.data;
|
||||
if (data.startsWith('[DONE]')) {
|
||||
yield { done: true, value: '' };
|
||||
break;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsedData = JSON.parse(data);
|
||||
console.log(parsedData);
|
||||
|
||||
yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' };
|
||||
} catch (e) {
|
||||
console.error('Error extracting delta from SSE event:', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// streamLargeDeltasAsRandomChunks will chunk large deltas (length > 5) into random sized chunks between 1-3 characters
|
||||
// This is to simulate a more fluid streaming, even though some providers may send large chunks of text at once
|
||||
async function* streamLargeDeltasAsRandomChunks(
|
||||
iterator: AsyncGenerator<TextStreamUpdate>
|
||||
): AsyncGenerator<TextStreamUpdate> {
|
||||
for await (const textStreamUpdate of iterator) {
|
||||
if (textStreamUpdate.done) {
|
||||
yield textStreamUpdate;
|
||||
return;
|
||||
}
|
||||
let content = textStreamUpdate.value;
|
||||
if (content.length < 5) {
|
||||
yield { done: false, value: content };
|
||||
continue;
|
||||
}
|
||||
while (content != '') {
|
||||
const chunkSize = Math.min(Math.floor(Math.random() * 3) + 1, content.length);
|
||||
const chunk = content.slice(0, chunkSize);
|
||||
yield { done: false, value: chunk };
|
||||
// Do not sleep if the tab is hidden
|
||||
// Timers are throttled to 1s in hidden tabs
|
||||
if (document?.visibilityState !== 'hidden') {
|
||||
await sleep(5);
|
||||
}
|
||||
content = content.slice(chunkSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
|
@ -83,9 +83,9 @@ export const downloadDatabase = async (token: string) => {
|
|||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
})
|
||||
.then((response) => {
|
||||
.then(async (response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error('Network response was not ok');
|
||||
throw await response.json();
|
||||
}
|
||||
return response.blob();
|
||||
})
|
||||
|
@ -100,7 +100,11 @@ export const downloadDatabase = async (token: string) => {
|
|||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
error = err;
|
||||
error = err.detail;
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
</script>
|
||||
|
||||
<Modal bind:show>
|
||||
<div class="px-5 py-4 dark:text-gray-300">
|
||||
<div class="px-5 py-4 dark:text-gray-300 text-gray-700">
|
||||
<div class="flex justify-between items-start">
|
||||
<div class="text-xl font-bold">
|
||||
{$i18n.t('What’s New in')}
|
||||
|
@ -32,6 +32,7 @@
|
|||
<button
|
||||
class="self-center"
|
||||
on:click={() => {
|
||||
localStorage.version = $config.version;
|
||||
show = false;
|
||||
}}
|
||||
>
|
||||
|
@ -58,7 +59,7 @@
|
|||
|
||||
<hr class=" dark:border-gray-800" />
|
||||
|
||||
<div class=" w-full p-4 px-5">
|
||||
<div class=" w-full p-4 px-5 text-gray-700 dark:text-gray-100">
|
||||
<div class=" overflow-y-scroll max-h-80">
|
||||
<div class="mb-3">
|
||||
{#if changelog}
|
||||
|
@ -109,7 +110,7 @@
|
|||
localStorage.version = $config.version;
|
||||
show = false;
|
||||
}}
|
||||
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
|
||||
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
>
|
||||
<span class="relative">{$i18n.t("Okay, Let's Go!")}</span>
|
||||
</button>
|
||||
|
|
334
src/lib/components/admin/AddUserModal.svelte
Normal file
334
src/lib/components/admin/AddUserModal.svelte
Normal file
|
@ -0,0 +1,334 @@
|
|||
<script lang="ts">
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { createEventDispatcher } from 'svelte';
|
||||
import { onMount, getContext } from 'svelte';
|
||||
import { addUser } from '$lib/apis/auths';
|
||||
|
||||
import Modal from '../common/Modal.svelte';
|
||||
import { WEBUI_BASE_URL } from '$lib/constants';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
export let show = false;
|
||||
|
||||
let loading = false;
|
||||
let tab = '';
|
||||
let inputFiles;
|
||||
|
||||
let _user = {
|
||||
name: '',
|
||||
email: '',
|
||||
password: '',
|
||||
role: 'user'
|
||||
};
|
||||
|
||||
$: if (show) {
|
||||
_user = {
|
||||
name: '',
|
||||
email: '',
|
||||
password: '',
|
||||
role: 'user'
|
||||
};
|
||||
}
|
||||
|
||||
const submitHandler = async () => {
|
||||
const stopLoading = () => {
|
||||
dispatch('save');
|
||||
loading = false;
|
||||
};
|
||||
|
||||
if (tab === '') {
|
||||
loading = true;
|
||||
|
||||
const res = await addUser(
|
||||
localStorage.token,
|
||||
_user.name,
|
||||
_user.email,
|
||||
_user.password,
|
||||
_user.role
|
||||
).catch((error) => {
|
||||
toast.error(error);
|
||||
});
|
||||
|
||||
if (res) {
|
||||
stopLoading();
|
||||
show = false;
|
||||
}
|
||||
} else {
|
||||
if (inputFiles) {
|
||||
loading = true;
|
||||
|
||||
const file = inputFiles[0];
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = async (e) => {
|
||||
const csv = e.target.result;
|
||||
const rows = csv.split('\n');
|
||||
|
||||
let userCount = 0;
|
||||
|
||||
for (const [idx, row] of rows.entries()) {
|
||||
const columns = row.split(',').map((col) => col.trim());
|
||||
console.log(idx, columns);
|
||||
|
||||
if (idx > 0) {
|
||||
if (columns.length === 4 && ['admin', 'user', 'pending'].includes(columns[3])) {
|
||||
const res = await addUser(
|
||||
localStorage.token,
|
||||
columns[0],
|
||||
columns[1],
|
||||
columns[2],
|
||||
columns[3]
|
||||
).catch((error) => {
|
||||
toast.error(`Row ${idx + 1}: ${error}`);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (res) {
|
||||
userCount = userCount + 1;
|
||||
}
|
||||
} else {
|
||||
toast.error(`Row ${idx + 1}: invalid format.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toast.success(`Successfully imported ${userCount} users.`);
|
||||
inputFiles = null;
|
||||
const uploadInputElement = document.getElementById('upload-user-csv-input');
|
||||
|
||||
if (uploadInputElement) {
|
||||
uploadInputElement.value = null;
|
||||
}
|
||||
|
||||
stopLoading();
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
} else {
|
||||
toast.error(`File not found.`);
|
||||
}
|
||||
}
|
||||
};
|
||||
</script>
|
||||
|
||||
<Modal size="sm" bind:show>
|
||||
<div>
|
||||
<div class=" flex justify-between dark:text-gray-300 px-5 pt-4 pb-2">
|
||||
<div class=" text-lg font-medium self-center">{$i18n.t('Add User')}</div>
|
||||
<button
|
||||
class="self-center"
|
||||
on:click={() => {
|
||||
show = false;
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
class="w-5 h-5"
|
||||
>
|
||||
<path
|
||||
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col md:flex-row w-full px-5 pb-4 md:space-x-4 dark:text-gray-200">
|
||||
<div class=" flex flex-col w-full sm:flex-row sm:justify-center sm:space-x-6">
|
||||
<form
|
||||
class="flex flex-col w-full"
|
||||
on:submit|preventDefault={() => {
|
||||
submitHandler();
|
||||
}}
|
||||
>
|
||||
<div class="flex text-center text-sm font-medium rounded-xl bg-transparent/10 p-1 mb-2">
|
||||
<button
|
||||
class="w-full rounded-lg p-1.5 {tab === '' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
|
||||
type="button"
|
||||
on:click={() => {
|
||||
tab = '';
|
||||
}}>Form</button
|
||||
>
|
||||
|
||||
<button
|
||||
class="w-full rounded-lg p-1 {tab === 'import' ? 'bg-gray-50 dark:bg-gray-850' : ''}"
|
||||
type="button"
|
||||
on:click={() => {
|
||||
tab = 'import';
|
||||
}}>CSV Import</button
|
||||
>
|
||||
</div>
|
||||
<div class="px-1">
|
||||
{#if tab === ''}
|
||||
<div class="flex flex-col w-full">
|
||||
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Role')}</div>
|
||||
|
||||
<div class="flex-1">
|
||||
<select
|
||||
class="w-full capitalize rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
|
||||
bind:value={_user.role}
|
||||
placeholder={$i18n.t('Enter Your Role')}
|
||||
required
|
||||
>
|
||||
<option value="pending"> pending </option>
|
||||
<option value="user"> user </option>
|
||||
<option value="admin"> admin </option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col w-full mt-2">
|
||||
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Name')}</div>
|
||||
|
||||
<div class="flex-1">
|
||||
<input
|
||||
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
|
||||
type="text"
|
||||
bind:value={_user.name}
|
||||
placeholder={$i18n.t('Enter Your Full Name')}
|
||||
autocomplete="off"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr class=" dark:border-gray-800 my-3 w-full" />
|
||||
|
||||
<div class="flex flex-col w-full">
|
||||
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Email')}</div>
|
||||
|
||||
<div class="flex-1">
|
||||
<input
|
||||
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
|
||||
type="email"
|
||||
bind:value={_user.email}
|
||||
placeholder={$i18n.t('Enter Your Email')}
|
||||
autocomplete="off"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col w-full mt-2">
|
||||
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Password')}</div>
|
||||
|
||||
<div class="flex-1">
|
||||
<input
|
||||
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 disabled:text-gray-500 dark:disabled:text-gray-500 outline-none"
|
||||
type="password"
|
||||
bind:value={_user.password}
|
||||
placeholder={$i18n.t('Enter Your Password')}
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{:else if tab === 'import'}
|
||||
<div>
|
||||
<div class="mb-3 w-full">
|
||||
<input
|
||||
id="upload-user-csv-input"
|
||||
hidden
|
||||
bind:files={inputFiles}
|
||||
type="file"
|
||||
accept=".csv"
|
||||
/>
|
||||
|
||||
<button
|
||||
class="w-full text-sm font-medium py-3 bg-transparent hover:bg-gray-100 border border-dashed dark:border-gray-800 dark:hover:bg-gray-850 text-center rounded-xl"
|
||||
type="button"
|
||||
on:click={() => {
|
||||
document.getElementById('upload-user-csv-input')?.click();
|
||||
}}
|
||||
>
|
||||
{#if inputFiles}
|
||||
{inputFiles.length > 0 ? `${inputFiles.length}` : ''} document(s) selected.
|
||||
{:else}
|
||||
{$i18n.t('Click here to select a csv file.')}
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class=" text-xs text-gray-500">
|
||||
ⓘ {$i18n.t(
|
||||
'Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.'
|
||||
)}
|
||||
<a
|
||||
class="underline dark:text-gray-200"
|
||||
href="{WEBUI_BASE_URL}/static/user-import.csv"
|
||||
>
|
||||
Click here to download user import template file.
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="flex justify-end pt-3 text-sm font-medium">
|
||||
<button
|
||||
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg flex flex-row space-x-1 items-center {loading
|
||||
? ' cursor-not-allowed'
|
||||
: ''}"
|
||||
type="submit"
|
||||
disabled={loading}
|
||||
>
|
||||
{$i18n.t('Submit')}
|
||||
|
||||
{#if loading}
|
||||
<div class="ml-2 self-center">
|
||||
<svg
|
||||
class=" w-4 h-4"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
><style>
|
||||
.spinner_ajPY {
|
||||
transform-origin: center;
|
||||
animation: spinner_AtaB 0.75s infinite linear;
|
||||
}
|
||||
@keyframes spinner_AtaB {
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
</style><path
|
||||
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
|
||||
opacity=".25"
|
||||
/><path
|
||||
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
|
||||
class="spinner_ajPY"
|
||||
/></svg
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
|
||||
<style>
|
||||
input::-webkit-outer-spin-button,
|
||||
input::-webkit-inner-spin-button {
|
||||
/* display: none; <- Crashes Chrome on hover */
|
||||
-webkit-appearance: none;
|
||||
margin: 0; /* <-- Apparently some margin are still there even though it's hidden */
|
||||
}
|
||||
|
||||
.tabs::-webkit-scrollbar {
|
||||
display: none; /* for Chrome, Safari and Opera */
|
||||
}
|
||||
|
||||
.tabs {
|
||||
-ms-overflow-style: none; /* IE and Edge */
|
||||
scrollbar-width: none; /* Firefox */
|
||||
}
|
||||
|
||||
input[type='number'] {
|
||||
-moz-appearance: textfield; /* Firefox */
|
||||
}
|
||||
</style>
|
|
@ -86,7 +86,7 @@
|
|||
|
||||
<div class="text-xs text-gray-500">
|
||||
{$i18n.t('Created at')}
|
||||
{dayjs(selectedUser.timestamp * 1000).format($i18n.t('MMMM DD, YYYY'))}
|
||||
{dayjs(selectedUser.created_at * 1000).format($i18n.t('MMMM DD, YYYY'))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -139,7 +139,7 @@
|
|||
|
||||
<div class="flex justify-end pt-3 text-sm font-medium">
|
||||
<button
|
||||
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
|
||||
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
type="submit"
|
||||
>
|
||||
{$i18n.t('Save')}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
<script lang="ts">
|
||||
import { downloadDatabase } from '$lib/apis/utils';
|
||||
import { onMount, getContext } from 'svelte';
|
||||
import { config } from '$lib/stores';
|
||||
import { toast } from 'svelte-sonner';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
|
@ -24,13 +26,16 @@
|
|||
<div class=" flex w-full justify-between">
|
||||
<!-- <div class=" self-center text-xs font-medium">{$i18n.t('Allow Chat Deletion')}</div> -->
|
||||
|
||||
{#if $config?.admin_export_enabled ?? true}
|
||||
<button
|
||||
class=" flex rounded-md py-1.5 px-3 w-full hover:bg-gray-200 dark:hover:bg-gray-800 transition"
|
||||
type="button"
|
||||
on:click={() => {
|
||||
// exportAllUserChats();
|
||||
|
||||
downloadDatabase(localStorage.token);
|
||||
downloadDatabase(localStorage.token).catch((error) => {
|
||||
toast.error(error);
|
||||
});
|
||||
}}
|
||||
>
|
||||
<div class=" self-center mr-3">
|
||||
|
@ -50,16 +55,18 @@
|
|||
</div>
|
||||
<div class=" self-center text-sm font-medium">{$i18n.t('Download Database')}</div>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- <div class="flex justify-end pt-3 text-sm font-medium">
|
||||
<button
|
||||
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
|
||||
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
type="submit"
|
||||
>
|
||||
Save
|
||||
</button>
|
||||
|
||||
</div> -->
|
||||
</form>
|
||||
|
|
|
@ -159,7 +159,7 @@
|
|||
|
||||
<div class="flex justify-end pt-3 text-sm font-medium">
|
||||
<button
|
||||
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
|
||||
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
type="submit"
|
||||
>
|
||||
{$i18n.t('Save')}
|
||||
|
|
|
@ -190,7 +190,7 @@
|
|||
|
||||
<div class="flex justify-end pt-3 text-sm font-medium">
|
||||
<button
|
||||
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
|
||||
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
type="submit"
|
||||
>
|
||||
{$i18n.t('Save')}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
<Modal bind:show>
|
||||
<div>
|
||||
<div class=" flex justify-between dark:text-gray-300 px-5 py-4">
|
||||
<div class=" flex justify-between dark:text-gray-300 px-5 pt-4 pb-2">
|
||||
<div class=" text-lg font-medium self-center">{$i18n.t('Admin Settings')}</div>
|
||||
<button
|
||||
class="self-center"
|
||||
|
@ -35,7 +35,6 @@
|
|||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<hr class=" dark:border-gray-800" />
|
||||
|
||||
<div class="flex flex-col md:flex-row w-full p-4 md:space-x-4">
|
||||
<div
|
||||
|
|
144
src/lib/components/admin/UserChatsModal.svelte
Normal file
144
src/lib/components/admin/UserChatsModal.svelte
Normal file
|
@ -0,0 +1,144 @@
|
|||
<script lang="ts">
|
||||
import { toast } from 'svelte-sonner';
|
||||
import dayjs from 'dayjs';
|
||||
import { getContext, createEventDispatcher } from 'svelte';
|
||||
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
import Modal from '$lib/components/common/Modal.svelte';
|
||||
import { getChatListByUserId, deleteChatById, getArchivedChatList } from '$lib/apis/chats';
|
||||
import Tooltip from '$lib/components/common/Tooltip.svelte';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
export let show = false;
|
||||
export let user;
|
||||
|
||||
let chats = [];
|
||||
|
||||
const deleteChatHandler = async (chatId) => {
|
||||
const res = await deleteChatById(localStorage.token, chatId).catch((error) => {
|
||||
toast.error(error);
|
||||
});
|
||||
|
||||
chats = await getChatListByUserId(localStorage.token, user.id);
|
||||
};
|
||||
|
||||
$: if (show) {
|
||||
(async () => {
|
||||
if (user.id) {
|
||||
chats = await getChatListByUserId(localStorage.token, user.id);
|
||||
}
|
||||
})();
|
||||
}
|
||||
</script>
|
||||
|
||||
<Modal size="lg" bind:show>
|
||||
<div>
|
||||
<div class=" flex justify-between dark:text-gray-300 px-5 py-4">
|
||||
<div class=" text-lg font-medium self-center capitalize">
|
||||
{$i18n.t("{{user}}'s Chats", { user: user.name })}
|
||||
</div>
|
||||
<button
|
||||
class="self-center"
|
||||
on:click={() => {
|
||||
show = false;
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
class="w-5 h-5"
|
||||
>
|
||||
<path
|
||||
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<hr class=" dark:border-gray-850" />
|
||||
|
||||
<div class="flex flex-col md:flex-row w-full px-5 py-4 md:space-x-4 dark:text-gray-200">
|
||||
<div class=" flex flex-col w-full sm:flex-row sm:justify-center sm:space-x-6">
|
||||
{#if chats.length > 0}
|
||||
<div class="text-left text-sm w-full mb-4 max-h-[22rem] overflow-y-scroll">
|
||||
<div class="relative overflow-x-auto">
|
||||
<table class="w-full text-sm text-left text-gray-600 dark:text-gray-400 table-auto">
|
||||
<thead
|
||||
class="text-xs text-gray-700 uppercase bg-transparent dark:text-gray-200 border-b-2 dark:border-gray-800"
|
||||
>
|
||||
<tr>
|
||||
<th scope="col" class="px-3 py-2"> {$i18n.t('Name')} </th>
|
||||
<th scope="col" class="px-3 py-2 hidden md:flex"> {$i18n.t('Created at')} </th>
|
||||
<th scope="col" class="px-3 py-2 text-right" />
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#each chats as chat, idx}
|
||||
<tr
|
||||
class="bg-transparent {idx !== chats.length - 1 &&
|
||||
'border-b'} dark:bg-gray-900 dark:border-gray-850 text-xs"
|
||||
>
|
||||
<td class="px-3 py-1 w-2/3">
|
||||
<a href="/s/{chat.id}" target="_blank">
|
||||
<div class=" underline line-clamp-1">
|
||||
{chat.title}
|
||||
</div>
|
||||
</a>
|
||||
</td>
|
||||
|
||||
<td class=" px-3 py-1 hidden md:flex h-[2.5rem]">
|
||||
<div class="my-auto">
|
||||
{dayjs(chat.created_at * 1000).format($i18n.t('MMMM DD, YYYY HH:mm'))}
|
||||
</div>
|
||||
</td>
|
||||
|
||||
<td class="px-3 py-1 text-right">
|
||||
<div class="flex justify-end w-full">
|
||||
<Tooltip content={$i18n.t('Delete Chat')}>
|
||||
<button
|
||||
class="self-center w-fit text-sm px-2 py-2 hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
|
||||
on:click={async () => {
|
||||
deleteChatHandler(chat.id);
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke="currentColor"
|
||||
class="w-4 h-4"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="m14.74 9-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 0 1-2.244 2.077H8.084a2.25 2.25 0 0 1-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 0 0-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 0 1 3.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 0 0-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 0 0-7.5 0"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
{/each}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- {#each chats as chat}
|
||||
<div>
|
||||
{JSON.stringify(chat)}
|
||||
</div>
|
||||
{/each} -->
|
||||
</div>
|
||||
{:else}
|
||||
<div class="text-left text-sm w-full mb-8">
|
||||
{user.name}
|
||||
{$i18n.t('has no conversations.')}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
|
@ -1,26 +1,34 @@
|
|||
<script lang="ts">
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { onMount, tick, getContext } from 'svelte';
|
||||
import { settings } from '$lib/stores';
|
||||
import { modelfiles, settings, showSidebar } from '$lib/stores';
|
||||
import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
|
||||
|
||||
import {
|
||||
uploadDocToVectorDB,
|
||||
uploadWebToVectorDB,
|
||||
uploadYoutubeTranscriptionToVectorDB
|
||||
} from '$lib/apis/rag';
|
||||
import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS, WEBUI_BASE_URL } from '$lib/constants';
|
||||
|
||||
import { transcribeAudio } from '$lib/apis/audio';
|
||||
|
||||
import Prompts from './MessageInput/PromptCommands.svelte';
|
||||
import Suggestions from './MessageInput/Suggestions.svelte';
|
||||
import { uploadDocToVectorDB, uploadWebToVectorDB } from '$lib/apis/rag';
|
||||
import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte';
|
||||
import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
|
||||
import Documents from './MessageInput/Documents.svelte';
|
||||
import Models from './MessageInput/Models.svelte';
|
||||
import { transcribeAudio } from '$lib/apis/audio';
|
||||
import Tooltip from '../common/Tooltip.svelte';
|
||||
import XMark from '$lib/components/icons/XMark.svelte';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
export let submitPrompt: Function;
|
||||
export let stopResponse: Function;
|
||||
|
||||
export let suggestionPrompts = [];
|
||||
export let autoScroll = true;
|
||||
export let selectedModel = '';
|
||||
|
||||
let chatTextAreaElement: HTMLTextAreaElement;
|
||||
let filesInputElement;
|
||||
|
||||
|
@ -290,11 +298,47 @@
|
|||
}
|
||||
};
|
||||
|
||||
const uploadYoutubeTranscription = async (url) => {
|
||||
console.log(url);
|
||||
|
||||
const doc = {
|
||||
type: 'doc',
|
||||
name: url,
|
||||
collection_name: '',
|
||||
upload_status: false,
|
||||
url: url,
|
||||
error: ''
|
||||
};
|
||||
|
||||
try {
|
||||
files = [...files, doc];
|
||||
const res = await uploadYoutubeTranscriptionToVectorDB(localStorage.token, url);
|
||||
|
||||
if (res) {
|
||||
doc.upload_status = true;
|
||||
doc.collection_name = res.collection_name;
|
||||
files = files;
|
||||
}
|
||||
} catch (e) {
|
||||
// Remove the failed doc from the files array
|
||||
files = files.filter((f) => f.name !== url);
|
||||
toast.error(e);
|
||||
}
|
||||
};
|
||||
|
||||
onMount(() => {
|
||||
console.log(document.getElementById('sidebar'));
|
||||
window.setTimeout(() => chatTextAreaElement?.focus(), 0);
|
||||
|
||||
const dropZone = document.querySelector('body');
|
||||
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
console.log('Escape');
|
||||
dragged = false;
|
||||
}
|
||||
};
|
||||
|
||||
const onDragOver = (e) => {
|
||||
e.preventDefault();
|
||||
dragged = true;
|
||||
|
@ -309,8 +353,13 @@
|
|||
console.log(e);
|
||||
|
||||
if (e.dataTransfer?.files) {
|
||||
let reader = new FileReader();
|
||||
const inputFiles = Array.from(e.dataTransfer?.files);
|
||||
|
||||
if (inputFiles && inputFiles.length > 0) {
|
||||
inputFiles.forEach((file) => {
|
||||
console.log(file, file.name.split('.').at(-1));
|
||||
if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
|
||||
let reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
files = [
|
||||
...files,
|
||||
|
@ -320,13 +369,6 @@
|
|||
}
|
||||
];
|
||||
};
|
||||
|
||||
const inputFiles = e.dataTransfer?.files;
|
||||
|
||||
if (inputFiles && inputFiles.length > 0) {
|
||||
const file = inputFiles[0];
|
||||
console.log(file, file.name.split('.').at(-1));
|
||||
if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
|
||||
reader.readAsDataURL(file);
|
||||
} else if (
|
||||
SUPPORTED_FILE_TYPE.includes(file['type']) ||
|
||||
|
@ -342,6 +384,7 @@
|
|||
);
|
||||
uploadDoc(file);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
toast.error($i18n.t(`File not found.`));
|
||||
}
|
||||
|
@ -350,11 +393,15 @@
|
|||
dragged = false;
|
||||
};
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown);
|
||||
|
||||
dropZone?.addEventListener('dragover', onDragOver);
|
||||
dropZone?.addEventListener('drop', onDrop);
|
||||
dropZone?.addEventListener('dragleave', onDragLeave);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('keydown', handleKeyDown);
|
||||
|
||||
dropZone?.removeEventListener('dragover', onDragOver);
|
||||
dropZone?.removeEventListener('drop', onDrop);
|
||||
dropZone?.removeEventListener('dragleave', onDragLeave);
|
||||
|
@ -364,7 +411,9 @@
|
|||
|
||||
{#if dragged}
|
||||
<div
|
||||
class="fixed lg:w-[calc(100%-260px)] w-full h-full flex z-50 touch-none pointer-events-none"
|
||||
class="fixed {$showSidebar
|
||||
? 'left-0 lg:left-[260px] lg:w-[calc(100%-260px)]'
|
||||
: 'left-0'} w-full h-full flex z-50 touch-none pointer-events-none"
|
||||
id="dropzone"
|
||||
role="region"
|
||||
aria-label="Drag and Drop Container"
|
||||
|
@ -379,12 +428,13 @@
|
|||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="w-full">
|
||||
<div class="px-2.5 -mb-0.5 mx-auto inset-x-0 bg-transparent flex justify-center">
|
||||
<div class="flex flex-col max-w-3xl w-full">
|
||||
<div class="fixed bottom-0 {$showSidebar ? 'left-0 lg:left-[260px]' : 'left-0'} right-0">
|
||||
<div class="w-full">
|
||||
<div class="px-2.5 lg:px-16 -mb-0.5 mx-auto inset-x-0 bg-transparent flex justify-center">
|
||||
<div class="flex flex-col max-w-5xl w-full">
|
||||
<div class="relative">
|
||||
{#if autoScroll === false && messages.length > 0}
|
||||
<div class=" absolute -top-12 left-0 right-0 flex justify-center">
|
||||
<div class=" absolute -top-12 left-0 right-0 flex justify-center z-30">
|
||||
<button
|
||||
class=" bg-white border border-gray-100 dark:border-none dark:bg-white/20 p-1.5 rounded-full"
|
||||
on:click={() => {
|
||||
|
@ -416,6 +466,10 @@
|
|||
<Documents
|
||||
bind:this={documentsElement}
|
||||
bind:prompt
|
||||
on:youtube={(e) => {
|
||||
console.log(e);
|
||||
uploadYoutubeTranscription(e.detail);
|
||||
}}
|
||||
on:url={(e) => {
|
||||
console.log(e);
|
||||
uploadWeb(e.detail);
|
||||
|
@ -432,31 +486,68 @@
|
|||
];
|
||||
}}
|
||||
/>
|
||||
{:else if prompt.charAt(0) === '@'}
|
||||
{/if}
|
||||
|
||||
<Models
|
||||
bind:this={modelsElement}
|
||||
bind:prompt
|
||||
bind:user
|
||||
bind:chatInputPlaceholder
|
||||
{messages}
|
||||
on:select={(e) => {
|
||||
selectedModel = e.detail;
|
||||
chatTextAreaElement?.focus();
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
{#if messages.length == 0 && suggestionPrompts.length !== 0}
|
||||
<Suggestions {suggestionPrompts} {submitPrompt} />
|
||||
{#if selectedModel !== ''}
|
||||
<div
|
||||
class="px-3 py-2.5 text-left w-full flex justify-between items-center absolute bottom-0 left-0 right-0 bg-gradient-to-t from-50% from-white dark:from-gray-900"
|
||||
>
|
||||
<div class="flex items-center gap-2 text-sm dark:text-gray-500">
|
||||
<img
|
||||
alt="model profile"
|
||||
class="size-5 max-w-[28px] object-cover rounded-full"
|
||||
src={$modelfiles.find((modelfile) => modelfile.tagName === selectedModel.id)
|
||||
?.imageUrl ??
|
||||
($i18n.language === 'dg-DG'
|
||||
? `/doge.png`
|
||||
: `${WEBUI_BASE_URL}/static/favicon.png`)}
|
||||
/>
|
||||
<div>
|
||||
Talking to <span class=" font-medium">{selectedModel.name} </span>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<button
|
||||
class="flex items-center"
|
||||
on:click={() => {
|
||||
selectedModel = '';
|
||||
}}
|
||||
>
|
||||
<XMark />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="bg-white dark:bg-gray-900">
|
||||
<div class="max-w-3xl px-2.5 mx-auto inset-x-0">
|
||||
<div class="max-w-6xl px-2.5 lg:px-16 mx-auto inset-x-0">
|
||||
<div class=" pb-2">
|
||||
<input
|
||||
bind:this={filesInputElement}
|
||||
bind:files={inputFiles}
|
||||
type="file"
|
||||
hidden
|
||||
multiple
|
||||
on:change={async () => {
|
||||
if (inputFiles && inputFiles.length > 0) {
|
||||
const _inputFiles = Array.from(inputFiles);
|
||||
_inputFiles.forEach((file) => {
|
||||
if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
|
||||
let reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
files = [
|
||||
|
@ -469,10 +560,6 @@
|
|||
inputFiles = null;
|
||||
filesInputElement.value = '';
|
||||
};
|
||||
|
||||
if (inputFiles && inputFiles.length > 0) {
|
||||
const file = inputFiles[0];
|
||||
if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
|
||||
reader.readAsDataURL(file);
|
||||
} else if (
|
||||
SUPPORTED_FILE_TYPE.includes(file['type']) ||
|
||||
|
@ -490,6 +577,7 @@
|
|||
uploadDoc(file);
|
||||
filesInputElement.value = '';
|
||||
}
|
||||
});
|
||||
} else {
|
||||
toast.error($i18n.t(`File not found.`));
|
||||
}
|
||||
|
@ -666,7 +754,7 @@
|
|||
<textarea
|
||||
id="chat-textarea"
|
||||
bind:this={chatTextAreaElement}
|
||||
class=" dark:bg-gray-900 dark:text-gray-100 outline-none w-full py-3 px-3 {fileUploadEnabled
|
||||
class="scrollbar-none dark:bg-gray-900 dark:text-gray-100 outline-none w-full py-3 px-3 {fileUploadEnabled
|
||||
? ''
|
||||
: ' pl-4'} rounded-xl resize-none h-[48px]"
|
||||
placeholder={chatInputPlaceholder !== ''
|
||||
|
@ -676,12 +764,21 @@
|
|||
: $i18n.t('Send a Message')}
|
||||
bind:value={prompt}
|
||||
on:keypress={(e) => {
|
||||
if (
|
||||
window.innerWidth > 1024 ||
|
||||
!(
|
||||
'ontouchstart' in window ||
|
||||
navigator.maxTouchPoints > 0 ||
|
||||
navigator.msMaxTouchPoints > 0
|
||||
)
|
||||
) {
|
||||
if (e.keyCode == 13 && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
}
|
||||
if (prompt !== '' && e.keyCode == 13 && !e.shiftKey) {
|
||||
submitPrompt(prompt, user);
|
||||
}
|
||||
}
|
||||
}}
|
||||
on:keydown={async (e) => {
|
||||
const isCtrlPressed = e.ctrlKey || e.metaKey; // metaKey is for Cmd key on Mac
|
||||
|
@ -744,7 +841,11 @@
|
|||
...document.getElementsByClassName('selected-command-option-button')
|
||||
]?.at(-1);
|
||||
|
||||
if (commandOptionButton) {
|
||||
commandOptionButton?.click();
|
||||
} else {
|
||||
document.getElementById('send-message-button')?.click();
|
||||
}
|
||||
}
|
||||
|
||||
if (['/', '#', '@'].includes(prompt.charAt(0)) && e.key === 'Tab') {
|
||||
|
@ -772,6 +873,14 @@
|
|||
e.preventDefault();
|
||||
e.target.setSelectionRange(word?.startIndex, word.endIndex + 1);
|
||||
}
|
||||
|
||||
e.target.style.height = '';
|
||||
e.target.style.height = Math.min(e.target.scrollHeight, 200) + 'px';
|
||||
}
|
||||
|
||||
if (e.key === 'Escape') {
|
||||
console.log('Escape');
|
||||
selectedModel = '';
|
||||
}
|
||||
}}
|
||||
rows="1"
|
||||
|
@ -883,6 +992,7 @@
|
|||
|
||||
<Tooltip content={$i18n.t('Send message')}>
|
||||
<button
|
||||
id="send-message-button"
|
||||
class="{prompt !== ''
|
||||
? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
|
||||
: 'text-white bg-gray-100 dark:text-gray-900 dark:bg-gray-800 disabled'} transition rounded-full p-1.5 self-center"
|
||||
|
@ -932,4 +1042,16 @@
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.scrollbar-none:active::-webkit-scrollbar-thumb,
|
||||
.scrollbar-none:focus::-webkit-scrollbar-thumb,
|
||||
.scrollbar-none:hover::-webkit-scrollbar-thumb {
|
||||
visibility: visible;
|
||||
}
|
||||
.scrollbar-none::-webkit-scrollbar-thumb {
|
||||
visibility: hidden;
|
||||
}
|
||||
</style>
|
||||
|
|
|
@ -87,6 +87,17 @@
|
|||
chatInputElement?.focus();
|
||||
await tick();
|
||||
};
|
||||
|
||||
const confirmSelectYoutube = async (url) => {
|
||||
dispatch('youtube', url);
|
||||
|
||||
prompt = removeFirstHashWord(prompt);
|
||||
const chatInputElement = document.getElementById('chat-textarea');
|
||||
|
||||
await tick();
|
||||
chatInputElement?.focus();
|
||||
await tick();
|
||||
};
|
||||
</script>
|
||||
|
||||
{#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
|
||||
|
@ -132,7 +143,30 @@
|
|||
</button>
|
||||
{/each}
|
||||
|
||||
{#if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
|
||||
{#if prompt.split(' ')?.at(0)?.substring(1).startsWith('https://www.youtube.com')}
|
||||
<button
|
||||
class="px-3 py-1.5 rounded-xl w-full text-left bg-gray-100 selected-command-option-button"
|
||||
type="button"
|
||||
on:click={() => {
|
||||
const url = prompt.split(' ')?.at(0)?.substring(1);
|
||||
if (isValidHttpUrl(url)) {
|
||||
confirmSelectYoutube(url);
|
||||
} else {
|
||||
toast.error(
|
||||
$i18n.t(
|
||||
'Oops! Looks like the URL is invalid. Please double-check and try again.'
|
||||
)
|
||||
);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div class=" font-medium text-black line-clamp-1">
|
||||
{prompt.split(' ')?.at(0)?.substring(1)}
|
||||
</div>
|
||||
|
||||
<div class=" text-xs text-gray-600 line-clamp-1">{$i18n.t('Youtube')}</div>
|
||||
</button>
|
||||
{:else if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
|
||||
<button
|
||||
class="px-3 py-1.5 rounded-xl w-full text-left bg-gray-100 selected-command-option-button"
|
||||
type="button"
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
<script lang="ts">
|
||||
import { createEventDispatcher } from 'svelte';
|
||||
|
||||
import { generatePrompt } from '$lib/apis/ollama';
|
||||
import { models } from '$lib/stores';
|
||||
import { splitStream } from '$lib/utils';
|
||||
|
@ -7,6 +9,8 @@
|
|||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
export let prompt = '';
|
||||
export let user = null;
|
||||
|
||||
|
@ -17,12 +21,7 @@
|
|||
let filteredModels = [];
|
||||
|
||||
$: filteredModels = $models
|
||||
.filter(
|
||||
(p) =>
|
||||
p.name !== 'hr' &&
|
||||
!p.external &&
|
||||
p.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')
|
||||
)
|
||||
.filter((p) => p.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? ''))
|
||||
.sort((a, b) => a.name.localeCompare(b.name));
|
||||
|
||||
$: if (prompt) {
|
||||
|
@ -38,6 +37,11 @@
|
|||
};
|
||||
|
||||
const confirmSelect = async (model) => {
|
||||
prompt = '';
|
||||
dispatch('select', model);
|
||||
};
|
||||
|
||||
const confirmSelectCollaborativeChat = async (model) => {
|
||||
// dispatch('select', model);
|
||||
prompt = '';
|
||||
user = JSON.parse(JSON.stringify(model.name));
|
||||
|
@ -127,7 +131,8 @@
|
|||
};
|
||||
</script>
|
||||
|
||||
{#if filteredModels.length > 0}
|
||||
{#if prompt.charAt(0) === '@'}
|
||||
{#if filteredModels.length > 0}
|
||||
<div class="md:px-2 mb-3 text-left w-full absolute bottom-0 left-0 right-0">
|
||||
<div class="flex w-full px-2">
|
||||
<div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-xl text-center">
|
||||
|
@ -163,4 +168,5 @@
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
|
|
|
@ -1,46 +1,87 @@
|
|||
<script lang="ts">
|
||||
import Bolt from '$lib/components/icons/Bolt.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
export let submitPrompt: Function;
|
||||
export let suggestionPrompts = [];
|
||||
|
||||
let prompts = [];
|
||||
|
||||
$: prompts =
|
||||
suggestionPrompts.length <= 4
|
||||
? suggestionPrompts
|
||||
: suggestionPrompts.sort(() => Math.random() - 0.5).slice(0, 4);
|
||||
$: prompts = suggestionPrompts
|
||||
.reduce((acc, current) => [...acc, ...[current]], [])
|
||||
.sort(() => Math.random() - 0.5);
|
||||
// suggestionPrompts.length <= 4
|
||||
// ? suggestionPrompts
|
||||
// : suggestionPrompts.sort(() => Math.random() - 0.5).slice(0, 4);
|
||||
|
||||
onMount(() => {
|
||||
const containerElement = document.getElementById('suggestions-container');
|
||||
|
||||
if (containerElement) {
|
||||
containerElement.addEventListener('wheel', function (event) {
|
||||
if (event.deltaY !== 0) {
|
||||
// If scrolling vertically, prevent default behavior
|
||||
event.preventDefault();
|
||||
// Adjust horizontal scroll position based on vertical scroll
|
||||
containerElement.scrollLeft += event.deltaY;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class=" mb-3 md:p-1 text-left w-full">
|
||||
<div class=" flex flex-wrap-reverse px-2 text-left">
|
||||
{#each prompts as prompt, promptIdx}
|
||||
{#if prompts.length > 0}
|
||||
<div class="mb-2 flex gap-1 text-sm font-medium items-center text-gray-400 dark:text-gray-600">
|
||||
<Bolt />
|
||||
Suggested
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="w-full">
|
||||
<div
|
||||
class="{promptIdx > 1 ? 'hidden sm:inline-flex' : ''} basis-full sm:basis-1/2 p-[5px] px-1"
|
||||
class="relative w-full flex gap-2 snap-x snap-mandatory md:snap-none overflow-x-auto tabs"
|
||||
id="suggestions-container"
|
||||
>
|
||||
{#each prompts as prompt, promptIdx}
|
||||
<div class="snap-center shrink-0">
|
||||
<button
|
||||
class=" flex-1 flex justify-between w-full h-full px-4 py-2.5 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 rounded-2xl transition group"
|
||||
class="flex flex-col flex-1 shrink-0 w-64 justify-between h-36 p-5 px-6 bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 rounded-3xl transition group"
|
||||
on:click={() => {
|
||||
submitPrompt(prompt.content);
|
||||
}}
|
||||
>
|
||||
<div class="flex flex-col text-left self-center">
|
||||
<div class="flex flex-col text-left">
|
||||
{#if prompt.title && prompt.title[0] !== ''}
|
||||
<div class="text-sm font-medium dark:text-gray-300">{prompt.title[0]}</div>
|
||||
<div class="text-sm text-gray-500 line-clamp-1">{prompt.title[1]}</div>
|
||||
<div
|
||||
class=" font-medium dark:text-gray-300 dark:group-hover:text-gray-200 transition"
|
||||
>
|
||||
{prompt.title[0]}
|
||||
</div>
|
||||
<div class="text-sm text-gray-600 font-normal line-clamp-2">{prompt.title[1]}</div>
|
||||
{:else}
|
||||
<div class=" self-center text-sm font-medium dark:text-gray-300 line-clamp-2">
|
||||
<div
|
||||
class=" self-center text-sm font-medium dark:text-gray-300 dark:group-hover:text-gray-100 transition line-clamp-2"
|
||||
>
|
||||
{prompt.content}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="w-full flex justify-between">
|
||||
<div
|
||||
class="self-center p-1 rounded-lg text-gray-50 group-hover:text-gray-800 dark:text-gray-850 dark:group-hover:text-gray-100 transition"
|
||||
class="text-xs text-gray-400 group-hover:text-gray-500 dark:text-gray-600 dark:group-hover:text-gray-500 transition self-center"
|
||||
>
|
||||
Prompt
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="self-end p-1 rounded-lg text-gray-300 group-hover:text-gray-800 dark:text-gray-700 dark:group-hover:text-gray-100 transition"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 16 16"
|
||||
fill="currentColor"
|
||||
class="w-4 h-4"
|
||||
class="size-4"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
|
@ -49,8 +90,27 @@
|
|||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
|
||||
<!-- <div class="snap-center shrink-0">
|
||||
<img
|
||||
class="shrink-0 w-80 h-40 rounded-lg shadow-xl bg-white"
|
||||
src="https://images.unsplash.com/photo-1604999565976-8913ad2ddb7c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=320&h=160&q=80"
|
||||
/>
|
||||
</div> -->
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.tabs::-webkit-scrollbar {
|
||||
display: none; /* for Chrome, Safari and Opera */
|
||||
}
|
||||
|
||||
.tabs {
|
||||
-ms-overflow-style: none; /* IE and Edge */
|
||||
scrollbar-width: none; /* Firefox */
|
||||
}
|
||||
</style>
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
import Placeholder from './Messages/Placeholder.svelte';
|
||||
import Spinner from '../common/Spinner.svelte';
|
||||
import { imageGenerations } from '$lib/apis/images';
|
||||
import { copyToClipboard, findWordIndices } from '$lib/utils';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
|
@ -21,6 +22,8 @@
|
|||
export let continueGeneration: Function;
|
||||
export let regenerateResponse: Function;
|
||||
|
||||
export let prompt;
|
||||
export let suggestionPrompts;
|
||||
export let processing = '';
|
||||
export let bottomPadding = false;
|
||||
export let autoScroll;
|
||||
|
@ -42,40 +45,11 @@
|
|||
element.scrollTop = element.scrollHeight;
|
||||
};
|
||||
|
||||
const copyToClipboard = (text) => {
|
||||
if (!navigator.clipboard) {
|
||||
var textArea = document.createElement('textarea');
|
||||
textArea.value = text;
|
||||
|
||||
// Avoid scrolling to bottom
|
||||
textArea.style.top = '0';
|
||||
textArea.style.left = '0';
|
||||
textArea.style.position = 'fixed';
|
||||
|
||||
document.body.appendChild(textArea);
|
||||
textArea.focus();
|
||||
textArea.select();
|
||||
|
||||
try {
|
||||
var successful = document.execCommand('copy');
|
||||
var msg = successful ? 'successful' : 'unsuccessful';
|
||||
console.log('Fallback: Copying text command was ' + msg);
|
||||
} catch (err) {
|
||||
console.error('Fallback: Oops, unable to copy', err);
|
||||
}
|
||||
|
||||
document.body.removeChild(textArea);
|
||||
return;
|
||||
}
|
||||
navigator.clipboard.writeText(text).then(
|
||||
function () {
|
||||
console.log('Async: Copying to clipboard was successful!');
|
||||
const copyToClipboardWithToast = async (text) => {
|
||||
const res = await copyToClipboard(text);
|
||||
if (res) {
|
||||
toast.success($i18n.t('Copying to clipboard was successful!'));
|
||||
},
|
||||
function (err) {
|
||||
console.error('Async: Could not copy text: ', err);
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
const confirmEditMessage = async (messageId, content) => {
|
||||
|
@ -107,12 +81,8 @@
|
|||
await sendPrompt(userPrompt, userMessageId, chatId);
|
||||
};
|
||||
|
||||
const confirmEditResponseMessage = async (messageId, content) => {
|
||||
history.messages[messageId].originalContent = history.messages[messageId].content;
|
||||
history.messages[messageId].content = content;
|
||||
|
||||
const updateChatMessages = async () => {
|
||||
await tick();
|
||||
|
||||
await updateChatById(localStorage.token, chatId, {
|
||||
messages: messages,
|
||||
history: history
|
||||
|
@ -121,15 +91,20 @@
|
|||
await chats.set(await getChatList(localStorage.token));
|
||||
};
|
||||
|
||||
const rateMessage = async (messageId, rating) => {
|
||||
history.messages[messageId].rating = rating;
|
||||
await tick();
|
||||
await updateChatById(localStorage.token, chatId, {
|
||||
messages: messages,
|
||||
history: history
|
||||
});
|
||||
const confirmEditResponseMessage = async (messageId, content) => {
|
||||
history.messages[messageId].originalContent = history.messages[messageId].content;
|
||||
history.messages[messageId].content = content;
|
||||
|
||||
await chats.set(await getChatList(localStorage.token));
|
||||
await updateChatMessages();
|
||||
};
|
||||
|
||||
const rateMessage = async (messageId, rating) => {
|
||||
history.messages[messageId].annotation = {
|
||||
...history.messages[messageId].annotation,
|
||||
rating: rating
|
||||
};
|
||||
|
||||
await updateChatMessages();
|
||||
};
|
||||
|
||||
const showPreviousMessage = async (message) => {
|
||||
|
@ -263,56 +238,58 @@
|
|||
history: history
|
||||
});
|
||||
};
|
||||
|
||||
// const messageDeleteHandler = async (messageId) => {
|
||||
// const message = history.messages[messageId];
|
||||
// const parentId = message.parentId;
|
||||
// const childrenIds = message.childrenIds ?? [];
|
||||
// const grandchildrenIds = [];
|
||||
|
||||
// // Iterate through childrenIds to find grandchildrenIds
|
||||
// for (const childId of childrenIds) {
|
||||
// const childMessage = history.messages[childId];
|
||||
// const grandChildrenIds = childMessage.childrenIds ?? [];
|
||||
|
||||
// for (const grandchildId of grandchildrenIds) {
|
||||
// const childMessage = history.messages[grandchildId];
|
||||
// childMessage.parentId = parentId;
|
||||
// }
|
||||
// grandchildrenIds.push(...grandChildrenIds);
|
||||
// }
|
||||
|
||||
// history.messages[parentId].childrenIds.push(...grandchildrenIds);
|
||||
// history.messages[parentId].childrenIds = history.messages[parentId].childrenIds.filter(
|
||||
// (id) => id !== messageId
|
||||
// );
|
||||
|
||||
// // Select latest message
|
||||
// let currentMessageId = grandchildrenIds.at(-1);
|
||||
// if (currentMessageId) {
|
||||
// let messageChildrenIds = history.messages[currentMessageId].childrenIds;
|
||||
// while (messageChildrenIds.length !== 0) {
|
||||
// currentMessageId = messageChildrenIds.at(-1);
|
||||
// messageChildrenIds = history.messages[currentMessageId].childrenIds;
|
||||
// }
|
||||
// history.currentId = currentMessageId;
|
||||
// }
|
||||
|
||||
// await updateChatById(localStorage.token, chatId, { messages, history });
|
||||
// };
|
||||
</script>
|
||||
|
||||
{#if messages.length == 0}
|
||||
<Placeholder models={selectedModels} modelfiles={selectedModelfiles} />
|
||||
{:else}
|
||||
<div class=" pb-10">
|
||||
<div class="h-full flex mb-16">
|
||||
{#if messages.length == 0}
|
||||
<Placeholder
|
||||
models={selectedModels}
|
||||
modelfiles={selectedModelfiles}
|
||||
{suggestionPrompts}
|
||||
submitPrompt={async (p) => {
|
||||
let text = p;
|
||||
|
||||
if (p.includes('{{CLIPBOARD}}')) {
|
||||
const clipboardText = await navigator.clipboard.readText().catch((err) => {
|
||||
toast.error($i18n.t('Failed to read clipboard contents'));
|
||||
return '{{CLIPBOARD}}';
|
||||
});
|
||||
|
||||
text = p.replaceAll('{{CLIPBOARD}}', clipboardText);
|
||||
}
|
||||
|
||||
prompt = text;
|
||||
|
||||
await tick();
|
||||
|
||||
const chatInputElement = document.getElementById('chat-textarea');
|
||||
if (chatInputElement) {
|
||||
prompt = p;
|
||||
|
||||
chatInputElement.style.height = '';
|
||||
chatInputElement.style.height = Math.min(chatInputElement.scrollHeight, 200) + 'px';
|
||||
chatInputElement.focus();
|
||||
|
||||
const words = findWordIndices(prompt);
|
||||
|
||||
if (words.length > 0) {
|
||||
const word = words.at(0);
|
||||
chatInputElement.setSelectionRange(word?.startIndex, word.endIndex + 1);
|
||||
}
|
||||
}
|
||||
|
||||
await tick();
|
||||
}}
|
||||
/>
|
||||
{:else}
|
||||
<div class="w-full pt-2">
|
||||
{#key chatId}
|
||||
{#each messages as message, messageIdx}
|
||||
<div class=" w-full">
|
||||
<div class=" w-full {messageIdx === messages.length - 1 ? 'pb-28' : ''}">
|
||||
<div
|
||||
class="flex flex-col justify-between px-5 mb-3 {$settings?.fullScreenMode ?? null
|
||||
? 'max-w-full'
|
||||
: 'max-w-3xl'} mx-auto rounded-lg group"
|
||||
: 'max-w-5xl'} mx-auto rounded-lg group"
|
||||
>
|
||||
{#if message.role === 'user'}
|
||||
<UserMessage
|
||||
|
@ -329,7 +306,7 @@
|
|||
{confirmEditMessage}
|
||||
{showPreviousMessage}
|
||||
{showNextMessage}
|
||||
{copyToClipboard}
|
||||
copyToClipboard={copyToClipboardWithToast}
|
||||
/>
|
||||
{:else}
|
||||
<ResponseMessage
|
||||
|
@ -338,11 +315,12 @@
|
|||
siblings={history.messages[message.parentId]?.childrenIds ?? []}
|
||||
isLastMessage={messageIdx + 1 === messages.length}
|
||||
{readOnly}
|
||||
{updateChatMessages}
|
||||
{confirmEditResponseMessage}
|
||||
{showPreviousMessage}
|
||||
{showNextMessage}
|
||||
{rateMessage}
|
||||
{copyToClipboard}
|
||||
copyToClipboard={copyToClipboardWithToast}
|
||||
{continueGeneration}
|
||||
{regenerateResponse}
|
||||
on:save={async (e) => {
|
||||
|
@ -362,8 +340,9 @@
|
|||
{/each}
|
||||
|
||||
{#if bottomPadding}
|
||||
<div class=" mb-10" />
|
||||
<div class=" pb-20" />
|
||||
{/if}
|
||||
{/key}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
|
|
|
@ -31,7 +31,9 @@
|
|||
>
|
||||
</div>
|
||||
|
||||
<pre class=" rounded-b-lg hljs p-4 px-5 overflow-x-auto rounded-t-none"><code
|
||||
<pre
|
||||
class=" hljs p-4 px-5 overflow-x-auto"
|
||||
style="border-top-left-radius: 0px; border-top-right-radius: 0px;"><code
|
||||
class="language-{lang} rounded-t-none whitespace-pre">{@html highlightedCode || code}</code
|
||||
></pre>
|
||||
</div>
|
||||
|
|
|
@ -3,11 +3,19 @@
|
|||
import { user } from '$lib/stores';
|
||||
import { onMount, getContext } from 'svelte';
|
||||
|
||||
import { blur, fade } from 'svelte/transition';
|
||||
|
||||
import Suggestions from '../MessageInput/Suggestions.svelte';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
export let models = [];
|
||||
export let modelfiles = [];
|
||||
|
||||
export let submitPrompt;
|
||||
export let suggestionPrompts;
|
||||
|
||||
let mounted = false;
|
||||
let modelfile = null;
|
||||
let selectedModelIdx = 0;
|
||||
|
||||
|
@ -17,12 +25,16 @@
|
|||
$: if (models.length > 0) {
|
||||
selectedModelIdx = models.length - 1;
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
mounted = true;
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if models.length > 0}
|
||||
<div class="m-auto text-center max-w-md px-2">
|
||||
<div class="flex justify-center mt-8">
|
||||
<div class="flex -space-x-4 mb-1">
|
||||
{#key mounted}
|
||||
<div class="m-auto w-full max-w-6xl px-8 lg:px-24 pb-16">
|
||||
<div class="flex justify-start">
|
||||
<div class="flex -space-x-4 mb-1" in:fade={{ duration: 200 }}>
|
||||
{#each models as model, modelIdx}
|
||||
<button
|
||||
on:click={() => {
|
||||
|
@ -33,15 +45,15 @@
|
|||
<img
|
||||
src={modelfiles[model]?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`}
|
||||
alt="modelfile"
|
||||
class=" size-12 rounded-full border-[1px] border-gray-200 dark:border-none"
|
||||
class=" size-[2.7rem] rounded-full border-[1px] border-gray-200 dark:border-none"
|
||||
draggable="false"
|
||||
/>
|
||||
{:else}
|
||||
<img
|
||||
src={models.length === 1
|
||||
? `${WEBUI_BASE_URL}/static/favicon.png`
|
||||
src={$i18n.language === 'dg-DG'
|
||||
? `/doge.png`
|
||||
: `${WEBUI_BASE_URL}/static/favicon.png`}
|
||||
class=" size-12 rounded-full border-[1px] border-gray-200 dark:border-none"
|
||||
class=" size-[2.7rem] rounded-full border-[1px] border-gray-200 dark:border-none"
|
||||
alt="logo"
|
||||
draggable="false"
|
||||
/>
|
||||
|
@ -50,26 +62,42 @@
|
|||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
<div class=" mt-2 mb-5 text-2xl text-gray-800 dark:text-gray-100 font-semibold">
|
||||
|
||||
<div
|
||||
class=" mt-2 mb-4 text-3xl text-gray-800 dark:text-gray-100 font-semibold text-left flex items-center gap-4"
|
||||
>
|
||||
<div>
|
||||
<div class=" capitalize line-clamp-1" in:fade={{ duration: 200 }}>
|
||||
{#if modelfile}
|
||||
<span class=" capitalize">
|
||||
{modelfile.title}
|
||||
</span>
|
||||
<div class="mt-0.5 text-base font-normal text-gray-600 dark:text-gray-400">
|
||||
{:else}
|
||||
{$i18n.t('Hello, {{name}}', { name: $user.name })}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div in:fade={{ duration: 200, delay: 200 }}>
|
||||
{#if modelfile}
|
||||
<div class="mt-0.5 text-base font-normal text-gray-500 dark:text-gray-400">
|
||||
{modelfile.desc}
|
||||
</div>
|
||||
{#if modelfile.user}
|
||||
<div class="mt-0.5 text-sm font-normal text-gray-500 dark:text-gray-500">
|
||||
<div class="mt-0.5 text-sm font-normal text-gray-400 dark:text-gray-500">
|
||||
By <a href="https://openwebui.com/m/{modelfile.user.username}"
|
||||
>{modelfile.user.name ? modelfile.user.name : `@${modelfile.user.username}`}</a
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
{:else}
|
||||
<div class=" line-clamp-1">{$i18n.t('Hello, {{name}}', { name: $user.name })}</div>
|
||||
|
||||
<div>{$i18n.t('How can I help you today?')}</div>
|
||||
<div class=" font-medium text-gray-400 dark:text-gray-500">
|
||||
{$i18n.t('How can I help you today?')}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class=" w-full" in:fade={{ duration: 200, delay: 300 }}>
|
||||
<Suggestions {suggestionPrompts} {submitPrompt} />
|
||||
</div>
|
||||
</div>
|
||||
{/key}
|
||||
|
|
129
src/lib/components/chat/Messages/RateComment.svelte
Normal file
129
src/lib/components/chat/Messages/RateComment.svelte
Normal file
|
@ -0,0 +1,129 @@
|
|||
<script lang="ts">
|
||||
import { toast } from 'svelte-sonner';
|
||||
|
||||
import { createEventDispatcher, onMount, getContext } from 'svelte';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
export let messageId = null;
|
||||
export let show = false;
|
||||
export let message;
|
||||
|
||||
let LIKE_REASONS = [];
|
||||
let DISLIKE_REASONS = [];
|
||||
|
||||
function loadReasons() {
|
||||
LIKE_REASONS = [
|
||||
$i18n.t('Accurate information'),
|
||||
$i18n.t('Followed instructions perfectly'),
|
||||
$i18n.t('Showcased creativity'),
|
||||
$i18n.t('Positive attitude'),
|
||||
$i18n.t('Attention to detail'),
|
||||
$i18n.t('Thorough explanation'),
|
||||
$i18n.t('Other')
|
||||
];
|
||||
|
||||
DISLIKE_REASONS = [
|
||||
$i18n.t("Don't like the style"),
|
||||
$i18n.t('Not factually correct'),
|
||||
$i18n.t("Didn't fully follow instructions"),
|
||||
$i18n.t("Refused when it shouldn't have"),
|
||||
$i18n.t('Being lazy'),
|
||||
$i18n.t('Other')
|
||||
];
|
||||
}
|
||||
|
||||
let reasons = [];
|
||||
let selectedReason = null;
|
||||
let comment = '';
|
||||
|
||||
$: if (message.annotation.rating === 1) {
|
||||
reasons = LIKE_REASONS;
|
||||
} else if (message.annotation.rating === -1) {
|
||||
reasons = DISLIKE_REASONS;
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
selectedReason = message.annotation.reason;
|
||||
comment = message.annotation.comment;
|
||||
loadReasons();
|
||||
});
|
||||
|
||||
const submitHandler = () => {
|
||||
console.log('submitHandler');
|
||||
|
||||
message.annotation.reason = selectedReason;
|
||||
message.annotation.comment = comment;
|
||||
|
||||
dispatch('submit');
|
||||
|
||||
toast.success($i18n.t('Thanks for your feedback!'));
|
||||
show = false;
|
||||
};
|
||||
</script>
|
||||
|
||||
<div
|
||||
class=" my-2.5 rounded-xl px-4 py-3 border dark:border-gray-850"
|
||||
id="message-feedback-{messageId}"
|
||||
>
|
||||
<div class="flex justify-between items-center">
|
||||
<div class=" text-sm">{$i18n.t('Tell us more:')}</div>
|
||||
|
||||
<button
|
||||
on:click={() => {
|
||||
show = false;
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke="currentColor"
|
||||
class="size-4"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M6 18 18 6M6 6l12 12" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{#if reasons.length > 0}
|
||||
<div class="flex flex-wrap gap-2 text-sm mt-2.5">
|
||||
{#each reasons as reason}
|
||||
<button
|
||||
class="px-3.5 py-1 border dark:border-gray-850 hover:bg-gray-100 dark:hover:bg-gray-850 {selectedReason ===
|
||||
reason
|
||||
? 'bg-gray-200 dark:bg-gray-800'
|
||||
: ''} transition rounded-lg"
|
||||
on:click={() => {
|
||||
selectedReason = reason;
|
||||
}}
|
||||
>
|
||||
{reason}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="mt-2">
|
||||
<textarea
|
||||
bind:value={comment}
|
||||
class="w-full text-sm px-1 py-2 bg-transparent outline-none resize-none rounded-xl"
|
||||
placeholder={$i18n.t('Feel free to add specific details')}
|
||||
rows="2"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="mt-2 flex justify-end">
|
||||
<button
|
||||
class=" bg-emerald-700 text-white text-sm font-medium rounded-lg px-3.5 py-1.5"
|
||||
on:click={() => {
|
||||
submitHandler();
|
||||
}}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
|
@ -15,9 +15,10 @@
|
|||
const dispatch = createEventDispatcher();
|
||||
|
||||
import { config, settings } from '$lib/stores';
|
||||
import { synthesizeOpenAISpeech } from '$lib/apis/openai';
|
||||
import { synthesizeOpenAISpeech } from '$lib/apis/audio';
|
||||
import { imageGenerations } from '$lib/apis/images';
|
||||
import {
|
||||
approximateToHumanReadable,
|
||||
extractSentences,
|
||||
revertSanitizedResponseContent,
|
||||
sanitizeResponseContent
|
||||
|
@ -30,6 +31,7 @@
|
|||
import Image from '$lib/components/common/Image.svelte';
|
||||
import { WEBUI_BASE_URL } from '$lib/constants';
|
||||
import Tooltip from '$lib/components/common/Tooltip.svelte';
|
||||
import RateComment from './RateComment.svelte';
|
||||
|
||||
export let modelfiles = [];
|
||||
export let message;
|
||||
|
@ -39,6 +41,7 @@
|
|||
|
||||
export let readOnly = false;
|
||||
|
||||
export let updateChatMessages: Function;
|
||||
export let confirmEditResponseMessage: Function;
|
||||
export let showPreviousMessage: Function;
|
||||
export let showNextMessage: Function;
|
||||
|
@ -60,6 +63,8 @@
|
|||
let loadingSpeech = false;
|
||||
let generatingImage = false;
|
||||
|
||||
let showRateComment = false;
|
||||
|
||||
$: tokens = marked.lexer(sanitizeResponseContent(message.content));
|
||||
|
||||
const renderer = new marked.Renderer();
|
||||
|
@ -118,16 +123,21 @@
|
|||
eval_count: ${message.info.eval_count ?? 'N/A'}<br/>
|
||||
eval_duration: ${
|
||||
Math.round(((message.info.eval_duration ?? 0) / 1000000) * 100) / 100 ?? 'N/A'
|
||||
}ms</span>`,
|
||||
}ms<br/>
|
||||
approximate_total: ${approximateToHumanReadable(
|
||||
message.info.total_duration
|
||||
)}</span>`,
|
||||
allowHTML: true
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const renderLatex = () => {
|
||||
let chatMessageElements = document.getElementsByClassName('chat-assistant');
|
||||
// let lastChatMessageElement = chatMessageElements[chatMessageElements.length - 1];
|
||||
let chatMessageElements = document
|
||||
.getElementById(`message-${message.id}`)
|
||||
?.getElementsByClassName('chat-assistant');
|
||||
|
||||
if (chatMessageElements) {
|
||||
for (const element of chatMessageElements) {
|
||||
auto_render(element, {
|
||||
// customised options
|
||||
|
@ -143,6 +153,7 @@
|
|||
throwOnError: false
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const playAudio = (idx) => {
|
||||
|
@ -168,10 +179,12 @@
|
|||
|
||||
const toggleSpeakMessage = async () => {
|
||||
if (speaking) {
|
||||
try {
|
||||
speechSynthesis.cancel();
|
||||
|
||||
sentencesAudio[speakingIdx].pause();
|
||||
sentencesAudio[speakingIdx].currentTime = 0;
|
||||
} catch {}
|
||||
|
||||
speaking = null;
|
||||
speakingIdx = null;
|
||||
|
@ -213,6 +226,10 @@
|
|||
sentence
|
||||
).catch((error) => {
|
||||
toast.error(error);
|
||||
|
||||
speaking = null;
|
||||
loadingSpeech = false;
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
|
@ -222,7 +239,6 @@
|
|||
const audio = new Audio(blobUrl);
|
||||
sentencesAudio[idx] = audio;
|
||||
loadingSpeech = false;
|
||||
|
||||
lastPlayedAudioPromise = lastPlayedAudioPromise.then(() => playAudio(idx));
|
||||
}
|
||||
}
|
||||
|
@ -309,9 +325,10 @@
|
|||
</script>
|
||||
|
||||
{#key message.id}
|
||||
<div class=" flex w-full message-{message.id}">
|
||||
<div class=" flex w-full message-{message.id}" id="message-{message.id}">
|
||||
<ProfileImage
|
||||
src={modelfiles[message.model]?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`}
|
||||
src={modelfiles[message.model]?.imageUrl ??
|
||||
($i18n.language === 'dg-DG' ? `/doge.png` : `${WEBUI_BASE_URL}/static/favicon.png`)}
|
||||
/>
|
||||
|
||||
<div class="w-full overflow-hidden">
|
||||
|
@ -363,7 +380,7 @@
|
|||
|
||||
<div class=" mt-2 mb-1 flex justify-center space-x-2 text-sm font-medium">
|
||||
<button
|
||||
class="px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded-lg"
|
||||
class="px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
on:click={() => {
|
||||
editMessageConfirmHandler();
|
||||
}}
|
||||
|
@ -478,7 +495,7 @@
|
|||
{/if}
|
||||
|
||||
{#if !readOnly}
|
||||
<Tooltip content="Edit" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Edit')} placement="bottom">
|
||||
<button
|
||||
class="{isLastMessage
|
||||
? 'visible'
|
||||
|
@ -505,7 +522,7 @@
|
|||
</Tooltip>
|
||||
{/if}
|
||||
|
||||
<Tooltip content="Copy" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Copy')} placement="bottom">
|
||||
<button
|
||||
class="{isLastMessage
|
||||
? 'visible'
|
||||
|
@ -532,15 +549,23 @@
|
|||
</Tooltip>
|
||||
|
||||
{#if !readOnly}
|
||||
<Tooltip content="Good Response" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Good Response')} placement="bottom">
|
||||
<button
|
||||
class="{isLastMessage
|
||||
? 'visible'
|
||||
: 'invisible group-hover:visible'} p-1 rounded {message.rating === 1
|
||||
: 'invisible group-hover:visible'} p-1 rounded {message?.annotation
|
||||
?.rating === 1
|
||||
? 'bg-gray-100 dark:bg-gray-800'
|
||||
: ''} dark:hover:text-white hover:text-black transition"
|
||||
on:click={() => {
|
||||
rateMessage(message.id, 1);
|
||||
showRateComment = true;
|
||||
|
||||
window.setTimeout(() => {
|
||||
document
|
||||
.getElementById(`message-feedback-${message.id}`)
|
||||
?.scrollIntoView();
|
||||
}, 0);
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
|
@ -559,15 +584,22 @@
|
|||
</button>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip content="Bad Response" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Bad Response')} placement="bottom">
|
||||
<button
|
||||
class="{isLastMessage
|
||||
? 'visible'
|
||||
: 'invisible group-hover:visible'} p-1 rounded {message.rating === -1
|
||||
: 'invisible group-hover:visible'} p-1 rounded {message?.annotation
|
||||
?.rating === -1
|
||||
? 'bg-gray-100 dark:bg-gray-800'
|
||||
: ''} dark:hover:text-white hover:text-black transition"
|
||||
on:click={() => {
|
||||
rateMessage(message.id, -1);
|
||||
showRateComment = true;
|
||||
window.setTimeout(() => {
|
||||
document
|
||||
.getElementById(`message-feedback-${message.id}`)
|
||||
?.scrollIntoView();
|
||||
}, 0);
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
|
@ -587,7 +619,7 @@
|
|||
</Tooltip>
|
||||
{/if}
|
||||
|
||||
<Tooltip content="Read Aloud" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Read Aloud')} placement="bottom">
|
||||
<button
|
||||
id="speak-button-{message.id}"
|
||||
class="{isLastMessage
|
||||
|
@ -736,7 +768,7 @@
|
|||
{/if}
|
||||
|
||||
{#if message.info}
|
||||
<Tooltip content="Generation Info" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Generation Info')} placement="bottom">
|
||||
<button
|
||||
class=" {isLastMessage
|
||||
? 'visible'
|
||||
|
@ -765,7 +797,7 @@
|
|||
{/if}
|
||||
|
||||
{#if isLastMessage && !readOnly}
|
||||
<Tooltip content="Continue Response" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Continue Response')} placement="bottom">
|
||||
<button
|
||||
type="button"
|
||||
class="{isLastMessage
|
||||
|
@ -797,7 +829,7 @@
|
|||
</button>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip content="Regenerate" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Regenerate')} placement="bottom">
|
||||
<button
|
||||
type="button"
|
||||
class="{isLastMessage
|
||||
|
@ -824,6 +856,17 @@
|
|||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if showRateComment}
|
||||
<RateComment
|
||||
messageId={message.id}
|
||||
bind:show={showRateComment}
|
||||
bind:message
|
||||
on:submit={() => {
|
||||
updateChatMessages();
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
|
|
@ -176,11 +176,24 @@
|
|||
e.target.style.height = '';
|
||||
e.target.style.height = `${e.target.scrollHeight}px`;
|
||||
}}
|
||||
on:keydown={(e) => {
|
||||
if (e.key === 'Escape') {
|
||||
document.getElementById('close-edit-message-button')?.click();
|
||||
}
|
||||
|
||||
const isCmdOrCtrlPressed = e.metaKey || e.ctrlKey;
|
||||
const isEnterPressed = e.key === 'Enter';
|
||||
|
||||
if (isCmdOrCtrlPressed && isEnterPressed) {
|
||||
document.getElementById('save-edit-message-button')?.click();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<div class=" mt-2 mb-1 flex justify-center space-x-2 text-sm font-medium">
|
||||
<button
|
||||
class="px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded-lg"
|
||||
id="save-edit-message-button"
|
||||
class="px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
|
||||
on:click={() => {
|
||||
editMessageConfirmHandler();
|
||||
}}
|
||||
|
@ -189,6 +202,7 @@
|
|||
</button>
|
||||
|
||||
<button
|
||||
id="close-edit-message-button"
|
||||
class=" px-4 py-2 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 text-gray-700 dark:text-gray-100 transition outline outline-1 outline-gray-200 dark:outline-gray-600 rounded-lg"
|
||||
on:click={() => {
|
||||
cancelEditMessage();
|
||||
|
@ -252,7 +266,7 @@
|
|||
{/if}
|
||||
|
||||
{#if !readOnly}
|
||||
<Tooltip content="Edit" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Edit')} placement="bottom">
|
||||
<button
|
||||
class="invisible group-hover:visible p-1 rounded dark:hover:text-white hover:text-black transition edit-user-message-button"
|
||||
on:click={() => {
|
||||
|
@ -277,7 +291,7 @@
|
|||
</Tooltip>
|
||||
{/if}
|
||||
|
||||
<Tooltip content="Copy" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Copy')} placement="bottom">
|
||||
<button
|
||||
class="invisible group-hover:visible p-1 rounded dark:hover:text-white hover:text-black transition"
|
||||
on:click={() => {
|
||||
|
@ -302,7 +316,7 @@
|
|||
</Tooltip>
|
||||
|
||||
{#if !isFirstMessage && !readOnly}
|
||||
<Tooltip content="Delete" placement="bottom">
|
||||
<Tooltip content={$i18n.t('Delete')} placement="bottom">
|
||||
<button
|
||||
class="invisible group-hover:visible p-1 rounded dark:hover:text-white hover:text-black transition"
|
||||
on:click={() => {
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
export let selectedModels = [''];
|
||||
export let disabled = false;
|
||||
|
||||
export let showSetDefault = true;
|
||||
|
||||
const saveDefaultModel = async () => {
|
||||
const hasEmptyModel = selectedModels.filter((it) => it === '');
|
||||
if (hasEmptyModel.length) {
|
||||
|
@ -38,9 +40,9 @@
|
|||
|
||||
<div class="flex flex-col mt-0.5 w-full">
|
||||
{#each selectedModels as selectedModel, selectedModelIdx}
|
||||
<div class="flex w-full">
|
||||
<div class="flex w-full max-w-fit">
|
||||
<div class="overflow-hidden w-full">
|
||||
<div class="mr-0.5 max-w-full">
|
||||
<div class="mr-1 max-w-full">
|
||||
<Selector
|
||||
placeholder={$i18n.t('Select a model')}
|
||||
items={$models
|
||||
|
@ -57,7 +59,7 @@
|
|||
|
||||
{#if selectedModelIdx === 0}
|
||||
<div class=" self-center mr-2 disabled:text-gray-600 disabled:hover:text-gray-600">
|
||||
<Tooltip content="Add Model">
|
||||
<Tooltip content={$i18n.t('Add Model')}>
|
||||
<button
|
||||
class=" "
|
||||
{disabled}
|
||||
|
@ -69,9 +71,9 @@
|
|||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke-width="2"
|
||||
stroke="currentColor"
|
||||
class="w-4 h-4"
|
||||
class="size-3.5"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M12 6v12m6-6H6" />
|
||||
</svg>
|
||||
|
@ -92,9 +94,9 @@
|
|||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke-width="2"
|
||||
stroke="currentColor"
|
||||
class="w-4 h-4"
|
||||
class="size-3.5"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" />
|
||||
</svg>
|
||||
|
@ -106,6 +108,8 @@
|
|||
{/each}
|
||||
</div>
|
||||
|
||||
<div class="text-left mt-0.5 ml-1 text-[0.7rem] text-gray-500">
|
||||
{#if showSetDefault}
|
||||
<div class="text-left mt-0.5 ml-1 text-[0.7rem] text-gray-500">
|
||||
<button on:click={saveDefaultModel}> {$i18n.t('Set as default')}</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { Select } from 'bits-ui';
|
||||
import { DropdownMenu } from 'bits-ui';
|
||||
|
||||
import { flyAndScale } from '$lib/utils/transitions';
|
||||
import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
|
||||
|
@ -21,10 +21,17 @@
|
|||
export let value = '';
|
||||
export let placeholder = 'Select a model';
|
||||
export let searchEnabled = true;
|
||||
export let searchPlaceholder = 'Search a model';
|
||||
export let searchPlaceholder = $i18n.t('Search a model');
|
||||
|
||||
export let items = [{ value: 'mango', label: 'Mango' }];
|
||||
|
||||
export let className = ' w-[32rem]';
|
||||
|
||||
let show = false;
|
||||
|
||||
let selectedModel = '';
|
||||
$: selectedModel = items.find((item) => item.value === value) ?? '';
|
||||
|
||||
let searchValue = '';
|
||||
let ollamaVersion = null;
|
||||
|
||||
|
@ -33,7 +40,7 @@
|
|||
: items;
|
||||
|
||||
const pullModelHandler = async () => {
|
||||
const sanitizedModelTag = searchValue.trim();
|
||||
const sanitizedModelTag = searchValue.trim().replace(/^ollama\s+(run|pull)\s+/, '');
|
||||
|
||||
console.log($MODEL_DOWNLOAD_POOL);
|
||||
if ($MODEL_DOWNLOAD_POOL[sanitizedModelTag]) {
|
||||
|
@ -175,27 +182,29 @@
|
|||
};
|
||||
</script>
|
||||
|
||||
<Select.Root
|
||||
{items}
|
||||
<DropdownMenu.Root
|
||||
bind:open={show}
|
||||
onOpenChange={async () => {
|
||||
searchValue = '';
|
||||
window.setTimeout(() => document.getElementById('model-search-input')?.focus(), 0);
|
||||
}}
|
||||
selected={items.find((item) => item.value === value) ?? ''}
|
||||
onSelectedChange={(selectedItem) => {
|
||||
value = selectedItem.value;
|
||||
}}
|
||||
>
|
||||
<Select.Trigger class="relative w-full" aria-label={placeholder}>
|
||||
<Select.Value
|
||||
class="flex text-left px-0.5 outline-none bg-transparent truncate text-lg font-semibold placeholder-gray-400 focus:outline-none"
|
||||
<DropdownMenu.Trigger class="relative w-full" aria-label={placeholder}>
|
||||
<div
|
||||
class="flex w-full text-left px-0.5 outline-none bg-transparent truncate text-lg font-semibold placeholder-gray-400 focus:outline-none"
|
||||
>
|
||||
{#if selectedModel}
|
||||
{selectedModel.label}
|
||||
{:else}
|
||||
{placeholder}
|
||||
/>
|
||||
<ChevronDown className="absolute end-2 top-1/2 -translate-y-[45%] size-3.5" strokeWidth="2.5" />
|
||||
</Select.Trigger>
|
||||
<Select.Content
|
||||
class=" z-40 w-full rounded-lg bg-white dark:bg-gray-900 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-700/50 outline-none"
|
||||
{/if}
|
||||
<ChevronDown className=" self-center ml-2 size-3" strokeWidth="2.5" />
|
||||
</div>
|
||||
</DropdownMenu.Trigger>
|
||||
<DropdownMenu.Content
|
||||
class=" z-40 {className} max-w-[calc(100vw-1rem)] justify-start rounded-lg bg-white dark:bg-gray-900 dark:text-white shadow-lg border border-gray-300/30 dark:border-gray-700/50 outline-none "
|
||||
transition={flyAndScale}
|
||||
side={'bottom-start'}
|
||||
sideOffset={4}
|
||||
>
|
||||
<slot>
|
||||
|
@ -208,18 +217,23 @@
|
|||
bind:value={searchValue}
|
||||
class="w-full text-sm bg-transparent outline-none"
|
||||
placeholder={searchPlaceholder}
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<hr class="border-gray-100 dark:border-gray-800" />
|
||||
{/if}
|
||||
|
||||
<div class="px-3 my-2 max-h-72 overflow-y-auto">
|
||||
<div class="px-3 my-2 max-h-72 overflow-y-auto scrollbar-none">
|
||||
{#each filteredItems as item}
|
||||
<Select.Item
|
||||
class="flex w-full font-medium line-clamp-1 select-none items-center rounded-button py-2 pl-3 pr-1.5 text-sm text-gray-700 dark:text-gray-100 outline-none transition-all duration-75 hover:bg-gray-100 dark:hover:bg-gray-850 rounded-lg cursor-pointer data-[highlighted]:bg-muted"
|
||||
value={item.value}
|
||||
label={item.label}
|
||||
<button
|
||||
aria-label="model-item"
|
||||
class="flex w-full text-left font-medium line-clamp-1 select-none items-center rounded-button py-2 pl-3 pr-1.5 text-sm text-gray-700 dark:text-gray-100 outline-none transition-all duration-75 hover:bg-gray-100 dark:hover:bg-gray-850 rounded-lg cursor-pointer data-[highlighted]:bg-muted"
|
||||
on:click={() => {
|
||||
value = item.value;
|
||||
|
||||
show = false;
|
||||
}}
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="line-clamp-1">
|
||||
|
@ -287,7 +301,7 @@
|
|||
<Check />
|
||||
</div>
|
||||
{/if}
|
||||
</Select.Item>
|
||||
</button>
|
||||
{:else}
|
||||
<div>
|
||||
<div class="block px-3 py-2 text-sm text-gray-700 dark:text-gray-100">
|
||||
|
@ -384,6 +398,20 @@
|
|||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<div class="hidden w-[42rem]" />
|
||||
<div class="hidden w-[32rem]" />
|
||||
</slot>
|
||||
</Select.Content>
|
||||
</Select.Root>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
|
||||
<style>
|
||||
.scrollbar-none:active::-webkit-scrollbar-thumb,
|
||||
.scrollbar-none:focus::-webkit-scrollbar-thumb,
|
||||
.scrollbar-none:hover::-webkit-scrollbar-thumb {
|
||||
visibility: visible;
|
||||
}
|
||||
.scrollbar-none::-webkit-scrollbar-thumb {
|
||||
visibility: hidden;
|
||||
}
|
||||
</style>
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue