forked from open-webui/open-webui
		
	
							parent
							
								
									b2dd2f191d
								
							
						
					
					
						commit
						8d34324d12
					
				
					 2 changed files with 126 additions and 63 deletions
				
			
		|  | @ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware | |||
| from fastapi.responses import StreamingResponse | ||||
| from fastapi.concurrency import run_in_threadpool | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from pydantic import BaseModel, ConfigDict | ||||
| 
 | ||||
| import random | ||||
| import requests | ||||
|  | @ -684,7 +684,7 @@ class GenerateChatCompletionForm(BaseModel): | |||
| 
 | ||||
| @app.post("/api/chat") | ||||
| @app.post("/api/chat/{url_idx}") | ||||
| async def generate_completion( | ||||
| async def generate_chat_completion( | ||||
|     form_data: GenerateChatCompletionForm, | ||||
|     url_idx: Optional[int] = None, | ||||
|     user=Depends(get_current_user), | ||||
|  | @ -765,6 +765,104 @@ async def generate_completion( | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| # TODO: we should update this part once Ollama supports other types | ||||
| class OpenAIChatMessage(BaseModel): | ||||
|     role: str | ||||
|     content: str | ||||
| 
 | ||||
|     model_config = ConfigDict(extra="allow") | ||||
| 
 | ||||
| 
 | ||||
| class OpenAIChatCompletionForm(BaseModel): | ||||
|     model: str | ||||
|     messages: List[OpenAIChatMessage] | ||||
| 
 | ||||
|     model_config = ConfigDict(extra="allow") | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/v1/chat/completions") | ||||
| @app.post("/v1/chat/completions/{url_idx}") | ||||
| async def generate_openai_chat_completion( | ||||
|     form_data: OpenAIChatCompletionForm, | ||||
|     url_idx: Optional[int] = None, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     if url_idx == None: | ||||
|         if form_data.model in app.state.MODELS: | ||||
|             url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=400, | ||||
|                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | ||||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|     def get_request(): | ||||
|         nonlocal form_data | ||||
|         nonlocal r | ||||
| 
 | ||||
|         request_id = str(uuid.uuid4()) | ||||
|         try: | ||||
|             REQUEST_POOL.append(request_id) | ||||
| 
 | ||||
|             def stream_content(): | ||||
|                 try: | ||||
|                     if form_data.stream: | ||||
|                         yield json.dumps( | ||||
|                             {"request_id": request_id, "done": False} | ||||
|                         ) + "\n" | ||||
| 
 | ||||
|                     for chunk in r.iter_content(chunk_size=8192): | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|                         r.close() | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             REQUEST_POOL.remove(request_id) | ||||
| 
 | ||||
|             r = requests.request( | ||||
|                 method="POST", | ||||
|                 url=f"{url}/v1/chat/completions", | ||||
|                 data=form_data.model_dump_json(exclude_none=True), | ||||
|                 stream=True, | ||||
|             ) | ||||
| 
 | ||||
|             r.raise_for_status() | ||||
| 
 | ||||
|             return StreamingResponse( | ||||
|                 stream_content(), | ||||
|                 status_code=r.status_code, | ||||
|                 headers=dict(r.headers), | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             raise e | ||||
| 
 | ||||
|     try: | ||||
|         return await run_in_threadpool(get_request) | ||||
|     except Exception as e: | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|                 res = r.json() | ||||
|                 if "error" in res: | ||||
|                     error_detail = f"Ollama: {res['error']}" | ||||
|             except: | ||||
|                 error_detail = f"Ollama: {e}" | ||||
| 
 | ||||
|         raise HTTPException( | ||||
|             status_code=r.status_code if r else 500, | ||||
|             detail=error_detail, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||
| async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||
|     url = app.state.OLLAMA_BASE_URLS[0] | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ | |||
| 	let selectedModelId = ''; | ||||
| 
 | ||||
| 	let loading = false; | ||||
| 	let currentRequestId; | ||||
| 	let currentRequestId = null; | ||||
| 	let stopResponseFlag = false; | ||||
| 
 | ||||
| 	let messagesContainerElement: HTMLDivElement; | ||||
|  | @ -92,6 +92,10 @@ | |||
| 			while (true) { | ||||
| 				const { value, done } = await reader.read(); | ||||
| 				if (done || stopResponseFlag) { | ||||
| 					if (stopResponseFlag) { | ||||
| 						await cancelChatCompletion(localStorage.token, currentRequestId); | ||||
| 					} | ||||
| 
 | ||||
| 					currentRequestId = null; | ||||
| 					break; | ||||
| 				} | ||||
|  | @ -108,10 +112,14 @@ | |||
| 								let data = JSON.parse(line.replace(/^data: /, '')); | ||||
| 								console.log(data); | ||||
| 
 | ||||
| 								if ('request_id' in data) { | ||||
| 									currentRequestId = data.request_id; | ||||
| 								} else { | ||||
| 									text += data.choices[0].delta.content ?? ''; | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} catch (error) { | ||||
| 					console.log(error); | ||||
| 				} | ||||
|  | @ -146,16 +154,6 @@ | |||
| 				: `${OLLAMA_API_BASE_URL}/v1` | ||||
| 		); | ||||
| 
 | ||||
| 		// const [res, controller] = await generateChatCompletion(localStorage.token, { | ||||
| 		// 	model: selectedModelId, | ||||
| 		// 	messages: [ | ||||
| 		// 		{ | ||||
| 		// 			role: 'assistant', | ||||
| 		// 			content: text | ||||
| 		// 		} | ||||
| 		// 	] | ||||
| 		// }); | ||||
| 
 | ||||
| 		let responseMessage; | ||||
| 		if (messages.at(-1)?.role === 'assistant') { | ||||
| 			responseMessage = messages.at(-1); | ||||
|  | @ -180,6 +178,11 @@ | |||
| 			while (true) { | ||||
| 				const { value, done } = await reader.read(); | ||||
| 				if (done || stopResponseFlag) { | ||||
| 					if (stopResponseFlag) { | ||||
| 						await cancelChatCompletion(localStorage.token, currentRequestId); | ||||
| 					} | ||||
| 
 | ||||
| 					currentRequestId = null; | ||||
| 					break; | ||||
| 				} | ||||
| 
 | ||||
|  | @ -196,6 +199,9 @@ | |||
| 								let data = JSON.parse(line.replace(/^data: /, '')); | ||||
| 								console.log(data); | ||||
| 
 | ||||
| 								if ('request_id' in data) { | ||||
| 									currentRequestId = data.request_id; | ||||
| 								} else { | ||||
| 									if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { | ||||
| 										continue; | ||||
| 									} else { | ||||
|  | @ -211,54 +217,13 @@ | |||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} catch (error) { | ||||
| 					console.log(error); | ||||
| 				} | ||||
| 
 | ||||
| 				scrollToBottom(); | ||||
| 			} | ||||
| 
 | ||||
| 			// while (true) { | ||||
| 			// 	const { value, done } = await reader.read(); | ||||
| 			// 	if (done || stopResponseFlag) { | ||||
| 			// 		if (stopResponseFlag) { | ||||
| 			// 			await cancelChatCompletion(localStorage.token, currentRequestId); | ||||
| 			// 		} | ||||
| 
 | ||||
| 			// 		currentRequestId = null; | ||||
| 			// 		break; | ||||
| 			// 	} | ||||
| 
 | ||||
| 			// 	try { | ||||
| 			// 		let lines = value.split('\n'); | ||||
| 
 | ||||
| 			// 		for (const line of lines) { | ||||
| 			// 			if (line !== '') { | ||||
| 			// 				console.log(line); | ||||
| 			// 				let data = JSON.parse(line); | ||||
| 
 | ||||
| 			// 				if ('detail' in data) { | ||||
| 			// 					throw data; | ||||
| 			// 				} | ||||
| 
 | ||||
| 			// 				if ('id' in data) { | ||||
| 			// 					console.log(data); | ||||
| 			// 					currentRequestId = data.id; | ||||
| 			// 				} else { | ||||
| 			// 					if (data.done == false) { | ||||
| 			// 						text += data.message.content; | ||||
| 			// 					} else { | ||||
| 			// 						console.log('done'); | ||||
| 			// 					} | ||||
| 			// 				} | ||||
| 			// 			} | ||||
| 			// 		} | ||||
| 			// 	} catch (error) { | ||||
| 			// 		console.log(error); | ||||
| 			// 	} | ||||
| 
 | ||||
| 			// 	scrollToBottom(); | ||||
| 			// } | ||||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek