forked from open-webui/open-webui
		
	
							parent
							
								
									b2dd2f191d
								
							
						
					
					
						commit
						8d34324d12
					
				
					 2 changed files with 126 additions and 63 deletions
				
			
		|  | @ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware | ||||||
| from fastapi.responses import StreamingResponse | from fastapi.responses import StreamingResponse | ||||||
| from fastapi.concurrency import run_in_threadpool | from fastapi.concurrency import run_in_threadpool | ||||||
| 
 | 
 | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel, ConfigDict | ||||||
| 
 | 
 | ||||||
| import random | import random | ||||||
| import requests | import requests | ||||||
|  | @ -684,7 +684,7 @@ class GenerateChatCompletionForm(BaseModel): | ||||||
| 
 | 
 | ||||||
| @app.post("/api/chat") | @app.post("/api/chat") | ||||||
| @app.post("/api/chat/{url_idx}") | @app.post("/api/chat/{url_idx}") | ||||||
| async def generate_completion( | async def generate_chat_completion( | ||||||
|     form_data: GenerateChatCompletionForm, |     form_data: GenerateChatCompletionForm, | ||||||
|     url_idx: Optional[int] = None, |     url_idx: Optional[int] = None, | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_current_user), | ||||||
|  | @ -765,6 +765,104 @@ async def generate_completion( | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # TODO: we should update this part once Ollama supports other types | ||||||
|  | class OpenAIChatMessage(BaseModel): | ||||||
|  |     role: str | ||||||
|  |     content: str | ||||||
|  | 
 | ||||||
|  |     model_config = ConfigDict(extra="allow") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class OpenAIChatCompletionForm(BaseModel): | ||||||
|  |     model: str | ||||||
|  |     messages: List[OpenAIChatMessage] | ||||||
|  | 
 | ||||||
|  |     model_config = ConfigDict(extra="allow") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @app.post("/v1/chat/completions") | ||||||
|  | @app.post("/v1/chat/completions/{url_idx}") | ||||||
|  | async def generate_openai_chat_completion( | ||||||
|  |     form_data: OpenAIChatCompletionForm, | ||||||
|  |     url_idx: Optional[int] = None, | ||||||
|  |     user=Depends(get_current_user), | ||||||
|  | ): | ||||||
|  | 
 | ||||||
|  |     if url_idx == None: | ||||||
|  |         if form_data.model in app.state.MODELS: | ||||||
|  |             url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) | ||||||
|  |         else: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||||
|  | 
 | ||||||
|  |     r = None | ||||||
|  | 
 | ||||||
|  |     def get_request(): | ||||||
|  |         nonlocal form_data | ||||||
|  |         nonlocal r | ||||||
|  | 
 | ||||||
|  |         request_id = str(uuid.uuid4()) | ||||||
|  |         try: | ||||||
|  |             REQUEST_POOL.append(request_id) | ||||||
|  | 
 | ||||||
|  |             def stream_content(): | ||||||
|  |                 try: | ||||||
|  |                     if form_data.stream: | ||||||
|  |                         yield json.dumps( | ||||||
|  |                             {"request_id": request_id, "done": False} | ||||||
|  |                         ) + "\n" | ||||||
|  | 
 | ||||||
|  |                     for chunk in r.iter_content(chunk_size=8192): | ||||||
|  |                         if request_id in REQUEST_POOL: | ||||||
|  |                             yield chunk | ||||||
|  |                         else: | ||||||
|  |                             print("User: canceled request") | ||||||
|  |                             break | ||||||
|  |                 finally: | ||||||
|  |                     if hasattr(r, "close"): | ||||||
|  |                         r.close() | ||||||
|  |                         if request_id in REQUEST_POOL: | ||||||
|  |                             REQUEST_POOL.remove(request_id) | ||||||
|  | 
 | ||||||
|  |             r = requests.request( | ||||||
|  |                 method="POST", | ||||||
|  |                 url=f"{url}/v1/chat/completions", | ||||||
|  |                 data=form_data.model_dump_json(exclude_none=True), | ||||||
|  |                 stream=True, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             r.raise_for_status() | ||||||
|  | 
 | ||||||
|  |             return StreamingResponse( | ||||||
|  |                 stream_content(), | ||||||
|  |                 status_code=r.status_code, | ||||||
|  |                 headers=dict(r.headers), | ||||||
|  |             ) | ||||||
|  |         except Exception as e: | ||||||
|  |             raise e | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         return await run_in_threadpool(get_request) | ||||||
|  |     except Exception as e: | ||||||
|  |         error_detail = "Open WebUI: Server Connection Error" | ||||||
|  |         if r is not None: | ||||||
|  |             try: | ||||||
|  |                 res = r.json() | ||||||
|  |                 if "error" in res: | ||||||
|  |                     error_detail = f"Ollama: {res['error']}" | ||||||
|  |             except: | ||||||
|  |                 error_detail = f"Ollama: {e}" | ||||||
|  | 
 | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=r.status_code if r else 500, | ||||||
|  |             detail=error_detail, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||||
| async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): | async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|     url = app.state.OLLAMA_BASE_URLS[0] |     url = app.state.OLLAMA_BASE_URLS[0] | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ | ||||||
| 	let selectedModelId = ''; | 	let selectedModelId = ''; | ||||||
| 
 | 
 | ||||||
| 	let loading = false; | 	let loading = false; | ||||||
| 	let currentRequestId; | 	let currentRequestId = null; | ||||||
| 	let stopResponseFlag = false; | 	let stopResponseFlag = false; | ||||||
| 
 | 
 | ||||||
| 	let messagesContainerElement: HTMLDivElement; | 	let messagesContainerElement: HTMLDivElement; | ||||||
|  | @ -92,6 +92,10 @@ | ||||||
| 			while (true) { | 			while (true) { | ||||||
| 				const { value, done } = await reader.read(); | 				const { value, done } = await reader.read(); | ||||||
| 				if (done || stopResponseFlag) { | 				if (done || stopResponseFlag) { | ||||||
|  | 					if (stopResponseFlag) { | ||||||
|  | 						await cancelChatCompletion(localStorage.token, currentRequestId); | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
| 					currentRequestId = null; | 					currentRequestId = null; | ||||||
| 					break; | 					break; | ||||||
| 				} | 				} | ||||||
|  | @ -108,7 +112,11 @@ | ||||||
| 								let data = JSON.parse(line.replace(/^data: /, '')); | 								let data = JSON.parse(line.replace(/^data: /, '')); | ||||||
| 								console.log(data); | 								console.log(data); | ||||||
| 
 | 
 | ||||||
| 								text += data.choices[0].delta.content ?? ''; | 								if ('request_id' in data) { | ||||||
|  | 									currentRequestId = data.request_id; | ||||||
|  | 								} else { | ||||||
|  | 									text += data.choices[0].delta.content ?? ''; | ||||||
|  | 								} | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
|  | @ -146,16 +154,6 @@ | ||||||
| 				: `${OLLAMA_API_BASE_URL}/v1` | 				: `${OLLAMA_API_BASE_URL}/v1` | ||||||
| 		); | 		); | ||||||
| 
 | 
 | ||||||
| 		// const [res, controller] = await generateChatCompletion(localStorage.token, { |  | ||||||
| 		// 	model: selectedModelId, |  | ||||||
| 		// 	messages: [ |  | ||||||
| 		// 		{ |  | ||||||
| 		// 			role: 'assistant', |  | ||||||
| 		// 			content: text |  | ||||||
| 		// 		} |  | ||||||
| 		// 	] |  | ||||||
| 		// }); |  | ||||||
| 
 |  | ||||||
| 		let responseMessage; | 		let responseMessage; | ||||||
| 		if (messages.at(-1)?.role === 'assistant') { | 		if (messages.at(-1)?.role === 'assistant') { | ||||||
| 			responseMessage = messages.at(-1); | 			responseMessage = messages.at(-1); | ||||||
|  | @ -180,6 +178,11 @@ | ||||||
| 			while (true) { | 			while (true) { | ||||||
| 				const { value, done } = await reader.read(); | 				const { value, done } = await reader.read(); | ||||||
| 				if (done || stopResponseFlag) { | 				if (done || stopResponseFlag) { | ||||||
|  | 					if (stopResponseFlag) { | ||||||
|  | 						await cancelChatCompletion(localStorage.token, currentRequestId); | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					currentRequestId = null; | ||||||
| 					break; | 					break; | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
|  | @ -196,17 +199,21 @@ | ||||||
| 								let data = JSON.parse(line.replace(/^data: /, '')); | 								let data = JSON.parse(line.replace(/^data: /, '')); | ||||||
| 								console.log(data); | 								console.log(data); | ||||||
| 
 | 
 | ||||||
| 								if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { | 								if ('request_id' in data) { | ||||||
| 									continue; | 									currentRequestId = data.request_id; | ||||||
| 								} else { | 								} else { | ||||||
| 									textareaElement.style.height = textareaElement.scrollHeight + 'px'; | 									if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { | ||||||
|  | 										continue; | ||||||
|  | 									} else { | ||||||
|  | 										textareaElement.style.height = textareaElement.scrollHeight + 'px'; | ||||||
| 
 | 
 | ||||||
| 									responseMessage.content += data.choices[0].delta.content ?? ''; | 										responseMessage.content += data.choices[0].delta.content ?? ''; | ||||||
| 									messages = messages; | 										messages = messages; | ||||||
| 
 | 
 | ||||||
| 									textareaElement.style.height = textareaElement.scrollHeight + 'px'; | 										textareaElement.style.height = textareaElement.scrollHeight + 'px'; | ||||||
| 
 | 
 | ||||||
| 									await tick(); | 										await tick(); | ||||||
|  | 									} | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
|  | @ -217,48 +224,6 @@ | ||||||
| 
 | 
 | ||||||
| 				scrollToBottom(); | 				scrollToBottom(); | ||||||
| 			} | 			} | ||||||
| 
 |  | ||||||
| 			// while (true) { |  | ||||||
| 			// 	const { value, done } = await reader.read(); |  | ||||||
| 			// 	if (done || stopResponseFlag) { |  | ||||||
| 			// 		if (stopResponseFlag) { |  | ||||||
| 			// 			await cancelChatCompletion(localStorage.token, currentRequestId); |  | ||||||
| 			// 		} |  | ||||||
| 
 |  | ||||||
| 			// 		currentRequestId = null; |  | ||||||
| 			// 		break; |  | ||||||
| 			// 	} |  | ||||||
| 
 |  | ||||||
| 			// 	try { |  | ||||||
| 			// 		let lines = value.split('\n'); |  | ||||||
| 
 |  | ||||||
| 			// 		for (const line of lines) { |  | ||||||
| 			// 			if (line !== '') { |  | ||||||
| 			// 				console.log(line); |  | ||||||
| 			// 				let data = JSON.parse(line); |  | ||||||
| 
 |  | ||||||
| 			// 				if ('detail' in data) { |  | ||||||
| 			// 					throw data; |  | ||||||
| 			// 				} |  | ||||||
| 
 |  | ||||||
| 			// 				if ('id' in data) { |  | ||||||
| 			// 					console.log(data); |  | ||||||
| 			// 					currentRequestId = data.id; |  | ||||||
| 			// 				} else { |  | ||||||
| 			// 					if (data.done == false) { |  | ||||||
| 			// 						text += data.message.content; |  | ||||||
| 			// 					} else { |  | ||||||
| 			// 						console.log('done'); |  | ||||||
| 			// 					} |  | ||||||
| 			// 				} |  | ||||||
| 			// 			} |  | ||||||
| 			// 		} |  | ||||||
| 			// 	} catch (error) { |  | ||||||
| 			// 		console.log(error); |  | ||||||
| 			// 	} |  | ||||||
| 
 |  | ||||||
| 			// 	scrollToBottom(); |  | ||||||
| 			// } |  | ||||||
| 		} | 		} | ||||||
| 	}; | 	}; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek