refac: litellm model name validation

This commit is contained in:
Timothy J. Baek 2024-04-21 18:25:53 -05:00
parent 5997774ab8
commit 4651db8c09

View file

@ -12,7 +12,7 @@ import json
import time import time
import requests import requests
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from typing import Optional, List from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user from utils.utils import get_verified_user, get_current_user, get_admin_user
@ -25,6 +25,7 @@ log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
from litellm.utils import get_llm_provider
import asyncio import asyncio
import subprocess import subprocess
@ -165,6 +166,8 @@ class LiteLLMConfigForm(BaseModel):
model_list: Optional[List[dict]] = None model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update") @app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
@ -236,13 +239,15 @@ class AddLiteLLMModelForm(BaseModel):
model_name: str model_name: str
litellm_params: dict litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new") @app.post("/model/new")
async def add_model_to_config( async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
): ):
# TODO: Validate model form try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump()) app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file: with open(LITELLM_CONFIG_DIR, "w") as file:
@ -251,6 +256,11 @@ async def add_model_to_config(
await restart_litellm() await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel): class DeleteLiteLLMModelForm(BaseModel):