forked from open-webui/open-webui
		
	main #2
					 3 changed files with 31 additions and 23 deletions
				
			
		|  | @ -98,13 +98,14 @@ def merge_models_lists(model_lists): | ||||||
|     merged_models = {} |     merged_models = {} | ||||||
| 
 | 
 | ||||||
|     for idx, model_list in enumerate(model_lists): |     for idx, model_list in enumerate(model_lists): | ||||||
|         for model in model_list: |         if model_list is not None: | ||||||
|             digest = model["digest"] |             for model in model_list: | ||||||
|             if digest not in merged_models: |                 digest = model["digest"] | ||||||
|                 model["urls"] = [idx] |                 if digest not in merged_models: | ||||||
|                 merged_models[digest] = model |                     model["urls"] = [idx] | ||||||
|             else: |                     merged_models[digest] = model | ||||||
|                 merged_models[digest]["urls"].append(idx) |                 else: | ||||||
|  |                     merged_models[digest]["urls"].append(idx) | ||||||
| 
 | 
 | ||||||
|     return list(merged_models.values()) |     return list(merged_models.values()) | ||||||
| 
 | 
 | ||||||
|  | @ -116,11 +117,10 @@ async def get_all_models(): | ||||||
|     print("get_all_models") |     print("get_all_models") | ||||||
|     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] |     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] | ||||||
|     responses = await asyncio.gather(*tasks) |     responses = await asyncio.gather(*tasks) | ||||||
|     responses = list(filter(lambda x: x is not None, responses)) |  | ||||||
| 
 | 
 | ||||||
|     models = { |     models = { | ||||||
|         "models": merge_models_lists( |         "models": merge_models_lists( | ||||||
|             map(lambda response: response["models"], responses) |             map(lambda response: response["models"] if response else None, responses) | ||||||
|         ) |         ) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -168,14 +168,15 @@ def merge_models_lists(model_lists): | ||||||
|     merged_list = [] |     merged_list = [] | ||||||
| 
 | 
 | ||||||
|     for idx, models in enumerate(model_lists): |     for idx, models in enumerate(model_lists): | ||||||
|         merged_list.extend( |         if models is not None and "error" not in models: | ||||||
|             [ |             merged_list.extend( | ||||||
|                 {**model, "urlIdx": idx} |                 [ | ||||||
|                 for model in models |                     {**model, "urlIdx": idx} | ||||||
|                 if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] |                     for model in models | ||||||
|                 or "gpt" in model["id"] |                     if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] | ||||||
|             ] |                     or "gpt" in model["id"] | ||||||
|         ) |                 ] | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     return merged_list |     return merged_list | ||||||
| 
 | 
 | ||||||
|  | @ -190,15 +191,20 @@ async def get_all_models(): | ||||||
|             fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) |             fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) | ||||||
|             for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) |             for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) | ||||||
|         ] |         ] | ||||||
|  | 
 | ||||||
|         responses = await asyncio.gather(*tasks) |         responses = await asyncio.gather(*tasks) | ||||||
|         responses = list( |  | ||||||
|             filter(lambda x: x is not None and "error" not in x, responses) |  | ||||||
|         ) |  | ||||||
|         models = { |         models = { | ||||||
|             "data": merge_models_lists( |             "data": merge_models_lists( | ||||||
|                 list(map(lambda response: response["data"], responses)) |                 list( | ||||||
|  |                     map( | ||||||
|  |                         lambda response: response["data"] if response else None, | ||||||
|  |                         responses, | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|             ) |             ) | ||||||
|         } |         } | ||||||
|  | 
 | ||||||
|  |         print(models) | ||||||
|         app.state.MODELS = {model["id"]: model for model in models["data"]} |         app.state.MODELS = {model["id"]: model for model in models["data"]} | ||||||
| 
 | 
 | ||||||
|         return models |         return models | ||||||
|  |  | ||||||
|  | @ -250,8 +250,10 @@ OPENAI_API_BASE_URLS = ( | ||||||
|     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL |     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")] | OPENAI_API_BASE_URLS = [ | ||||||
| 
 |     url.strip() if url != "" else "https://api.openai.com/v1" | ||||||
|  |     for url in OPENAI_API_BASE_URLS.split(";") | ||||||
|  | ] | ||||||
| 
 | 
 | ||||||
| #################################### | #################################### | ||||||
| # WEBUI | # WEBUI | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue