Merge pull request #1791 from cheahjs/feat/external-db-support

feat: add support for using postgres for the backend DB
This commit is contained in:
Timothy Jaeryang Baek 2024-04-27 12:48:56 -07:00 committed by GitHub
commit c1d85f8a6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 350 additions and 24 deletions

View file

@ -21,6 +21,8 @@ from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES from constants import MESSAGES
import os
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"]) log.setLevel(SRC_LOG_LEVELS["LITELLM"])
@ -62,6 +64,13 @@ app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference # Global variable to store the subprocess reference
background_process = None background_process = None
CONFLICT_ENV_VARS = [
# Uvicorn uses PORT, so LiteLLM might use it as well
"PORT",
# LiteLLM uses DATABASE_URL for Prisma connections
"DATABASE_URL",
]
async def run_background_process(command): async def run_background_process(command):
global background_process global background_process
@ -70,9 +79,11 @@ async def run_background_process(command):
try: try:
# Log the command to be executed # Log the command to be executed
log.info(f"Executing command: {command}") log.info(f"Executing command: {command}")
# Filter environment variables known to conflict with litellm
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
# Execute the command and create a subprocess # Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
) )
background_process = process background_process = process
log.info("Subprocess started successfully.") log.info("Subprocess started successfully.")

View file

@ -1,6 +1,7 @@
from peewee import * from peewee import *
from peewee_migrate import Router from peewee_migrate import Router
from config import SRC_LOG_LEVELS, DATA_DIR from playhouse.db_url import connect
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
import os import os
import logging import logging
@ -11,12 +12,12 @@ log.setLevel(SRC_LOG_LEVELS["DB"])
if os.path.exists(f"{DATA_DIR}/ollama.db"): if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file # Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("File renamed successfully.") log.info("Database migrated from Ollama-WebUI successfully.")
else: else:
pass pass
DB = connect(DATABASE_URL)
DB = SqliteDatabase(f"{DATA_DIR}/webui.db") log.info(f"Connected to a {DB.__class__.__name__} database.")
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log) router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run() router.run()
DB.connect(reuse_if_open=True) DB.connect(reuse_if_open=True)

View file

@ -37,6 +37,18 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
# We perform different migrations for SQLite and other databases
# This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite
# will require per-database SQL queries.
# Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base
# schema instead of trying to migrate from an older schema.
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model @migrator.create_model
class Auth(pw.Model): class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True) id = pw.CharField(max_length=255, unique=True)
@ -129,6 +141,99 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
table_name = "user" table_name = "user"
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.TextField()
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.TextField()
filename = pw.TextField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
content = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""

View file

@ -37,6 +37,13 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table # Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields( migrator.add_fields(
"chat", "chat",
@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
) )
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
rollback_sqlite(migrator, database, fake=fake)
else:
rollback_external(migrator, database, fake=fake)
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition # Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
@ -75,3 +113,18 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
# Finally, alter the timestamp field to not allow nulls if that was the original setting # Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))

View file

@ -0,0 +1,130 @@
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
timestamp=pw.BigIntegerField(),
)
# Alter the tables with varchar to text where necessary
migrator.change_fields(
"auth",
password=pw.TextField(),
)
migrator.change_fields(
"chat",
title=pw.TextField(),
)
migrator.change_fields(
"document",
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
title=pw.TextField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.TextField(),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
title=pw.CharField(),
)
migrator.change_fields(
"document",
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
title=pw.CharField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.CharField(),
)

View file

@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Auth(Model): class Auth(Model):
id = CharField(unique=True) id = CharField(unique=True)
email = CharField() email = CharField()
password = CharField() password = TextField()
active = BooleanField() active = BooleanField()
class Meta: class Meta:

View file

@ -17,11 +17,11 @@ from apps.web.internal.db import DB
class Chat(Model): class Chat(Model):
id = CharField(unique=True) id = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = CharField() title = TextField()
chat = TextField() # Save Chat JSON as Text chat = TextField() # Save Chat JSON as Text
created_at = DateTimeField() created_at = BigIntegerField()
updated_at = DateTimeField() updated_at = BigIntegerField()
share_id = CharField(null=True, unique=True) share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False) archived = BooleanField(default=False)

View file

@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Document(Model): class Document(Model):
collection_name = CharField(unique=True) collection_name = CharField(unique=True)
name = CharField(unique=True) name = CharField(unique=True)
title = CharField() title = TextField()
filename = CharField() filename = TextField()
content = TextField(null=True) content = TextField(null=True)
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -20,7 +20,7 @@ class Modelfile(Model):
tag_name = CharField(unique=True) tag_name = CharField(unique=True)
user_id = CharField() user_id = CharField()
modelfile = TextField() modelfile = TextField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -19,9 +19,9 @@ import json
class Prompt(Model): class Prompt(Model):
command = CharField(unique=True) command = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = CharField() title = TextField()
content = TextField() content = TextField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -35,7 +35,7 @@ class ChatIdTag(Model):
tag_name = CharField() tag_name = CharField()
chat_id = CharField() chat_id = CharField()
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -18,8 +18,8 @@ class User(Model):
name = CharField() name = CharField()
email = CharField() email = CharField()
role = CharField() role = CharField()
profile_image_url = CharField() profile_image_url = TextField()
timestamp = DateField() timestamp = BigIntegerField()
api_key = CharField(null=True, unique=True) api_key = CharField(null=True, unique=True)
class Meta: class Meta:

View file

@ -1,3 +1,5 @@
import logging
from fastapi import Request from fastapi import Request
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status

View file

@ -1,5 +1,6 @@
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
@ -7,7 +8,7 @@ from pydantic import BaseModel
from fpdf import FPDF from fpdf import FPDF
import markdown import markdown
from apps.web.internal.db import DB
from utils.utils import get_admin_user from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url from utils.misc import calculate_sha256, get_gravatar_url
@ -96,8 +97,13 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
if not isinstance(DB, SqliteDatabase):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse( return FileResponse(
f"{DATA_DIR}/webui.db", DB.database,
media_type="application/octet-stream", media_type="application/octet-stream",
filename="webui.db", filename="webui.db",
) )

View file

@ -534,3 +534,10 @@ LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT") raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")

View file

@ -69,3 +69,5 @@ class ERROR_MESSAGES(str, Enum):
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."

View file

@ -15,6 +15,8 @@ requests
aiohttp aiohttp
peewee peewee
peewee-migrate peewee-migrate
psycopg2-binary
pymysql
bcrypt bcrypt
litellm==1.35.17 litellm==1.35.17

View file

@ -83,9 +83,9 @@ export const downloadDatabase = async (token: string) => {
Authorization: `Bearer ${token}` Authorization: `Bearer ${token}`
} }
}) })
.then((response) => { .then(async (response) => {
if (!response.ok) { if (!response.ok) {
throw new Error('Network response was not ok'); throw await response.json();
} }
return response.blob(); return response.blob();
}) })
@ -100,7 +100,11 @@ export const downloadDatabase = async (token: string) => {
}) })
.catch((err) => { .catch((err) => {
console.log(err); console.log(err);
error = err; error = err.detail;
return null; return null;
}); });
if (error) {
throw error;
}
}; };

View file

@ -2,6 +2,7 @@
import { downloadDatabase } from '$lib/apis/utils'; import { downloadDatabase } from '$lib/apis/utils';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { config } from '$lib/stores'; import { config } from '$lib/stores';
import { toast } from 'svelte-sonner';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -32,7 +33,9 @@
on:click={() => { on:click={() => {
// exportAllUserChats(); // exportAllUserChats();
downloadDatabase(localStorage.token); downloadDatabase(localStorage.token).catch((error) => {
toast.error(error);
});
}} }}
> >
<div class=" self-center mr-3"> <div class=" self-center mr-3">