feat: add support for using postgres for the backend DB

This commit is contained in:
Jun Siang Cheah 2024-04-24 18:10:18 +01:00
parent f8f9f27ae8
commit e91a49c455
15 changed files with 329 additions and 18 deletions

View file

@ -21,6 +21,8 @@ from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES from constants import MESSAGES
import os
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"]) log.setLevel(SRC_LOG_LEVELS["LITELLM"])
@ -62,6 +64,13 @@ app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference # Global variable to store the subprocess reference
background_process = None background_process = None
CONFLICT_ENV_VARS = [
# Uvicorn uses PORT, so LiteLLM might use it as well
"PORT",
# LiteLLM uses DATABASE_URL for Prisma connections
"DATABASE_URL",
]
async def run_background_process(command): async def run_background_process(command):
global background_process global background_process
@ -70,9 +79,11 @@ async def run_background_process(command):
try: try:
# Log the command to be executed # Log the command to be executed
log.info(f"Executing command: {command}") log.info(f"Executing command: {command}")
# Filter environment variables known to conflict with litellm
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
# Execute the command and create a subprocess # Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
) )
background_process = process background_process = process
log.info("Subprocess started successfully.") log.info("Subprocess started successfully.")

View file

@ -1,6 +1,7 @@
from peewee import * from peewee import *
from peewee_migrate import Router from peewee_migrate import Router
from config import SRC_LOG_LEVELS, DATA_DIR from playhouse.db_url import connect
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
import os import os
import logging import logging
@ -11,12 +12,12 @@ log.setLevel(SRC_LOG_LEVELS["DB"])
if os.path.exists(f"{DATA_DIR}/ollama.db"): if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file # Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("File renamed successfully.") log.info("Database migrated from Ollama-WebUI successfully.")
else: else:
pass pass
DB = connect(DATABASE_URL)
DB = SqliteDatabase(f"{DATA_DIR}/webui.db") log.info(f"Connected to a {DB.__class__.__name__} database.")
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log) router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run() router.run()
DB.connect(reuse_if_open=True) DB.connect(reuse_if_open=True)

View file

@ -37,6 +37,18 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
# We perform different migrations for SQLite and other databases
# This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite
# will require per-database SQL queries.
# Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base
# schema instead of trying to migrate from an older schema.
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model @migrator.create_model
class Auth(pw.Model): class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True) id = pw.CharField(max_length=255, unique=True)
@ -129,6 +141,99 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
table_name = "user" table_name = "user"
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.TextField()
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.TextField()
filename = pw.TextField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
content = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""

View file

@ -37,6 +37,13 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table # Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields( migrator.add_fields(
"chat", "chat",
@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
) )
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
rollback_sqlite(migrator, database, fake=fake)
else:
rollback_external(migrator, database, fake=fake)
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition # Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
@ -75,3 +113,18 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
# Finally, alter the timestamp field to not allow nulls if that was the original setting # Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))

View file

@ -0,0 +1,130 @@
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
timestamp=pw.BigIntegerField(),
)
# Alter the tables with varchar to text where necessary
migrator.change_fields(
"auth",
password=pw.TextField(),
)
migrator.change_fields(
"chat",
title=pw.TextField(),
)
migrator.change_fields(
"document",
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
title=pw.TextField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.TextField(),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
title=pw.CharField(),
)
migrator.change_fields(
"document",
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
title=pw.CharField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.CharField(),
)

View file

@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Auth(Model): class Auth(Model):
id = CharField(unique=True) id = CharField(unique=True)
email = CharField() email = CharField()
password = CharField() password = TextField()
active = BooleanField() active = BooleanField()
class Meta: class Meta:

View file

@ -17,11 +17,11 @@ from apps.web.internal.db import DB
class Chat(Model): class Chat(Model):
id = CharField(unique=True) id = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = CharField() title = TextField()
chat = TextField() # Save Chat JSON as Text chat = TextField() # Save Chat JSON as Text
created_at = DateTimeField() created_at = BigIntegerField()
updated_at = DateTimeField() updated_at = BigIntegerField()
share_id = CharField(null=True, unique=True) share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False) archived = BooleanField(default=False)

View file

@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Document(Model): class Document(Model):
collection_name = CharField(unique=True) collection_name = CharField(unique=True)
name = CharField(unique=True) name = CharField(unique=True)
title = CharField() title = TextField()
filename = CharField() filename = TextField()
content = TextField(null=True) content = TextField(null=True)
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -20,7 +20,7 @@ class Modelfile(Model):
tag_name = CharField(unique=True) tag_name = CharField(unique=True)
user_id = CharField() user_id = CharField()
modelfile = TextField() modelfile = TextField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -19,9 +19,9 @@ import json
class Prompt(Model): class Prompt(Model):
command = CharField(unique=True) command = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = CharField() title = TextField()
content = TextField() content = TextField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -35,7 +35,7 @@ class ChatIdTag(Model):
tag_name = CharField() tag_name = CharField()
chat_id = CharField() chat_id = CharField()
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB

View file

@ -18,8 +18,8 @@ class User(Model):
name = CharField() name = CharField()
email = CharField() email = CharField()
role = CharField() role = CharField()
profile_image_url = CharField() profile_image_url = TextField()
timestamp = DateField() timestamp = BigIntegerField()
api_key = CharField(null=True, unique=True) api_key = CharField(null=True, unique=True)
class Meta: class Meta:

View file

@ -1,3 +1,5 @@
import logging
from fastapi import Request from fastapi import Request
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status

View file

@ -534,3 +534,10 @@ LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT") raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")

View file

@ -15,6 +15,8 @@ requests
aiohttp aiohttp
peewee peewee
peewee-migrate peewee-migrate
psycopg2-binary
pymysql
bcrypt bcrypt
litellm==1.35.17 litellm==1.35.17