fix: multiple openai issue

This commit is contained in:
Timothy J. Baek 2024-03-18 01:11:48 -07:00
parent e414b9ea6d
commit 1bfcd801b7
3 changed files with 31 additions and 23 deletions

View file

@ -98,13 +98,14 @@ def merge_models_lists(model_lists):
merged_models = {} merged_models = {}
for idx, model_list in enumerate(model_lists): for idx, model_list in enumerate(model_lists):
for model in model_list: if model_list is not None:
digest = model["digest"] for model in model_list:
if digest not in merged_models: digest = model["digest"]
model["urls"] = [idx] if digest not in merged_models:
merged_models[digest] = model model["urls"] = [idx]
else: merged_models[digest] = model
merged_models[digest]["urls"].append(idx) else:
merged_models[digest]["urls"].append(idx)
return list(merged_models.values()) return list(merged_models.values())
@ -116,11 +117,10 @@ async def get_all_models():
print("get_all_models") print("get_all_models")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
models = { models = {
"models": merge_models_lists( "models": merge_models_lists(
map(lambda response: response["models"], responses) map(lambda response: response["models"] if response else None, responses)
) )
} }

View file

@ -168,14 +168,15 @@ def merge_models_lists(model_lists):
merged_list = [] merged_list = []
for idx, models in enumerate(model_lists): for idx, models in enumerate(model_lists):
merged_list.extend( if models is not None and "error" not in models:
[ merged_list.extend(
{**model, "urlIdx": idx} [
for model in models {**model, "urlIdx": idx}
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] for model in models
or "gpt" in model["id"] if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
] or "gpt" in model["id"]
) ]
)
return merged_list return merged_list
@ -190,15 +191,20 @@ async def get_all_models():
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(
filter(lambda x: x is not None and "error" not in x, responses)
)
models = { models = {
"data": merge_models_lists( "data": merge_models_lists(
list(map(lambda response: response["data"], responses)) list(
map(
lambda response: response["data"] if response else None,
responses,
)
)
) )
} }
print(models)
app.state.MODELS = {model["id"]: model for model in models["data"]} app.state.MODELS = {model["id"]: model for model in models["data"]}
return models return models

View file

@ -250,8 +250,10 @@ OPENAI_API_BASE_URLS = (
OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
) )
OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")] OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";")
]
#################################### ####################################
# WEBUI # WEBUI