forked from open-webui/open-webui
		
	Merge pull request #498 from bhulston/fix/chat-imports
Feat: Add ChatGPT import functionality
This commit is contained in:
		
						commit
						578e78cb39
					
				
					 4 changed files with 110 additions and 21 deletions
				
			
		|  | @ -17,7 +17,8 @@ from apps.web.models.chats import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from utils.utils import ( | from utils.utils import ( | ||||||
|     bearer_scheme, ) |     bearer_scheme, | ||||||
|  | ) | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
| router = APIRouter() | router = APIRouter() | ||||||
|  | @ -29,7 +30,8 @@ router = APIRouter() | ||||||
| 
 | 
 | ||||||
| @router.get("/", response_model=List[ChatTitleIdResponse]) | @router.get("/", response_model=List[ChatTitleIdResponse]) | ||||||
| async def get_user_chats( | async def get_user_chats( | ||||||
|         user=Depends(get_current_user), skip: int = 0, limit: int = 50): |     user=Depends(get_current_user), skip: int = 0, limit: int = 50 | ||||||
|  | ): | ||||||
|     return Chats.get_chat_lists_by_user_id(user.id, skip, limit) |     return Chats.get_chat_lists_by_user_id(user.id, skip, limit) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -41,9 +43,8 @@ async def get_user_chats( | ||||||
| @router.get("/all", response_model=List[ChatResponse]) | @router.get("/all", response_model=List[ChatResponse]) | ||||||
| async def get_all_user_chats(user=Depends(get_current_user)): | async def get_all_user_chats(user=Depends(get_current_user)): | ||||||
|     return [ |     return [ | ||||||
|         ChatResponse(**{ |         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||||
|             **chat.model_dump(), "chat": json.loads(chat.chat) |         for chat in Chats.get_all_chats_by_user_id(user.id) | ||||||
|         }) for chat in Chats.get_all_chats_by_user_id(user.id) |  | ||||||
|     ] |     ] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -54,8 +55,14 @@ async def get_all_user_chats(user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| @router.post("/new", response_model=Optional[ChatResponse]) | @router.post("/new", response_model=Optional[ChatResponse]) | ||||||
| async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): | async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): | ||||||
|     chat = Chats.insert_new_chat(user.id, form_data) |     try: | ||||||
|     return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) |         chat = Chats.insert_new_chat(user.id, form_data) | ||||||
|  |         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(e) | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ############################ | ############################ | ||||||
|  | @ -68,12 +75,11 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): | ||||||
|     chat = Chats.get_chat_by_id_and_user_id(id, user.id) |     chat = Chats.get_chat_by_id_and_user_id(id, user.id) | ||||||
| 
 | 
 | ||||||
|     if chat: |     if chat: | ||||||
|         return ChatResponse(**{ |         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||||
|             **chat.model_dump(), "chat": json.loads(chat.chat) |  | ||||||
|         }) |  | ||||||
|     else: |     else: | ||||||
|         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, |         raise HTTPException( | ||||||
|                             detail=ERROR_MESSAGES.NOT_FOUND) |             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ############################ | ############################ | ||||||
|  | @ -82,17 +88,15 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.post("/{id}", response_model=Optional[ChatResponse]) | @router.post("/{id}", response_model=Optional[ChatResponse]) | ||||||
| async def update_chat_by_id(id: str, | async def update_chat_by_id( | ||||||
|                             form_data: ChatForm, |     id: str, form_data: ChatForm, user=Depends(get_current_user) | ||||||
|                             user=Depends(get_current_user)): | ): | ||||||
|     chat = Chats.get_chat_by_id_and_user_id(id, user.id) |     chat = Chats.get_chat_by_id_and_user_id(id, user.id) | ||||||
|     if chat: |     if chat: | ||||||
|         updated_chat = {**json.loads(chat.chat), **form_data.chat} |         updated_chat = {**json.loads(chat.chat), **form_data.chat} | ||||||
| 
 | 
 | ||||||
|         chat = Chats.update_chat_by_id(id, updated_chat) |         chat = Chats.update_chat_by_id(id, updated_chat) | ||||||
|         return ChatResponse(**{ |         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||||
|             **chat.model_dump(), "chat": json.loads(chat.chat) |  | ||||||
|         }) |  | ||||||
|     else: |     else: | ||||||
|         raise HTTPException( |         raise HTTPException( | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |             status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ | ||||||
| 	import { WEB_UI_VERSION, WEBUI_API_BASE_URL } from '$lib/constants'; | 	import { WEB_UI_VERSION, WEBUI_API_BASE_URL } from '$lib/constants'; | ||||||
| 
 | 
 | ||||||
| 	import { config, models, settings, user, chats } from '$lib/stores'; | 	import { config, models, settings, user, chats } from '$lib/stores'; | ||||||
| 	import { splitStream, getGravatarURL } from '$lib/utils'; | 	import { splitStream, getGravatarURL, getImportOrigin, convertOpenAIChats } from '$lib/utils'; | ||||||
| 
 | 
 | ||||||
| 	import Advanced from './Settings/Advanced.svelte'; | 	import Advanced from './Settings/Advanced.svelte'; | ||||||
| 	import Modal from '../common/Modal.svelte'; | 	import Modal from '../common/Modal.svelte'; | ||||||
|  | @ -132,6 +132,13 @@ | ||||||
| 		reader.onload = (event) => { | 		reader.onload = (event) => { | ||||||
| 			let chats = JSON.parse(event.target.result); | 			let chats = JSON.parse(event.target.result); | ||||||
| 			console.log(chats); | 			console.log(chats); | ||||||
|  | 			if (getImportOrigin(chats) == 'openai') { | ||||||
|  | 				try { | ||||||
|  | 					chats = convertOpenAIChats(chats); | ||||||
|  | 				} catch (error) { | ||||||
|  | 					console.log('Unable to import chats:', error); | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 			importChats(chats); | 			importChats(chats); | ||||||
| 		}; | 		}; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -192,3 +192,74 @@ export const calculateSHA256 = async (file) => { | ||||||
| 		throw error; | 		throw error; | ||||||
| 	} | 	} | ||||||
| }; | }; | ||||||
|  | 
 | ||||||
|  | export const getImportOrigin = (_chats) => { | ||||||
|  | 	// Check what external service chat imports are from
 | ||||||
|  | 	if ('mapping' in _chats[0]) { | ||||||
|  | 		return 'openai'; | ||||||
|  | 	} | ||||||
|  | 	return 'webui'; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | const convertOpenAIMessages = (convo) => { | ||||||
|  | 	// Parse OpenAI chat messages and create chat dictionary for creating new chats
 | ||||||
|  | 	const mapping = convo['mapping']; | ||||||
|  | 	const messages = []; | ||||||
|  | 	let currentId = ''; | ||||||
|  | 
 | ||||||
|  | 	for (let message_id in mapping) { | ||||||
|  | 		const message = mapping[message_id]; | ||||||
|  | 		currentId = message_id; | ||||||
|  | 		if (message['message'] == null || message['message']['content']['parts'][0] == '') { | ||||||
|  | 			// Skip chat messages with no content
 | ||||||
|  | 			continue; | ||||||
|  | 		} else { | ||||||
|  | 			const new_chat = { | ||||||
|  | 				id: message_id, | ||||||
|  | 				parentId: messages.length > 0 && message['parent'] in mapping ? message['parent'] : null, | ||||||
|  | 				childrenIds: message['children'] || [], | ||||||
|  | 				role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user', | ||||||
|  | 				content: message['message']?.['content']?.['parts']?.[0] || '', | ||||||
|  | 				model: 'gpt-3.5-turbo', | ||||||
|  | 				done: true, | ||||||
|  | 				context: null | ||||||
|  | 			}; | ||||||
|  | 			messages.push(new_chat); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	let history = {}; | ||||||
|  | 	messages.forEach((obj) => (history[obj.id] = obj)); | ||||||
|  | 
 | ||||||
|  | 	const chat = { | ||||||
|  | 		history: { | ||||||
|  | 			currentId: currentId, | ||||||
|  | 			messages: history // Need to convert this to not a list and instead a json object
 | ||||||
|  | 		}, | ||||||
|  | 		models: ['gpt-3.5-turbo'], | ||||||
|  | 		messages: messages, | ||||||
|  | 		options: {}, | ||||||
|  | 		timestamp: convo['create_time'], | ||||||
|  | 		title: convo['title'] ?? 'New Chat' | ||||||
|  | 	}; | ||||||
|  | 	return chat; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | export const convertOpenAIChats = (_chats) => { | ||||||
|  | 	// Create a list of dictionaries with each conversation from import
 | ||||||
|  | 	const chats = []; | ||||||
|  | 	for (let convo of _chats) { | ||||||
|  | 		const chat = convertOpenAIMessages(convo); | ||||||
|  | 
 | ||||||
|  | 		if (Object.keys(chat.history.messages).length > 0) { | ||||||
|  | 			chats.push({ | ||||||
|  | 				id: convo['id'], | ||||||
|  | 				user_id: '', | ||||||
|  | 				title: convo['title'], | ||||||
|  | 				chat: chat, | ||||||
|  | 				timestamp: convo['timestamp'] | ||||||
|  | 			}); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return chats; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | @ -200,6 +200,13 @@ | ||||||
| 					await chatId.set('local'); | 					await chatId.set('local'); | ||||||
| 				} | 				} | ||||||
| 				await tick(); | 				await tick(); | ||||||
|  | 			} else if (chat.chat["models"] != selectedModels) { | ||||||
|  | 				// If model is not saved in DB, then save selectedmodel when message is sent | ||||||
|  | 
 | ||||||
|  | 				chat = await updateChatById(localStorage.token, $chatId, { | ||||||
|  | 						models: selectedModels | ||||||
|  | 					}); | ||||||
|  | 				await chats.set(await getChatList(localStorage.token)); | ||||||
| 			} | 			} | ||||||
| 			 | 			 | ||||||
| 			// Reset chat input textarea | 			// Reset chat input textarea | ||||||
|  | @ -696,7 +703,7 @@ | ||||||
| 	<div class="min-h-screen w-full flex justify-center"> | 	<div class="min-h-screen w-full flex justify-center"> | ||||||
| 		<div class=" py-2.5 flex flex-col justify-between w-full"> | 		<div class=" py-2.5 flex flex-col justify-between w-full"> | ||||||
| 			<div class="max-w-2xl mx-auto w-full px-3 md:px-0 mt-10"> | 			<div class="max-w-2xl mx-auto w-full px-3 md:px-0 mt-10"> | ||||||
| 				<ModelSelector bind:selectedModels disabled={messages.length > 0} /> | 				<ModelSelector bind:selectedModels disabled={messages.length > 0 && !selectedModels.includes('')} /> | ||||||
| 			</div> | 			</div> | ||||||
| 
 | 
 | ||||||
| 			<div class=" h-full mt-10 mb-32 w-full flex flex-col"> | 			<div class=" h-full mt-10 mb-32 w-full flex flex-col"> | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek