forked from open-webui/open-webui
		
	Merge pull request #498 from bhulston/fix/chat-imports
Feat: Add ChatGPT import functionality
This commit is contained in:
		
						commit
						578e78cb39
					
				
					 4 changed files with 110 additions and 21 deletions
				
			
		|  | @ -17,7 +17,8 @@ from apps.web.models.chats import ( | |||
| ) | ||||
| 
 | ||||
| from utils.utils import ( | ||||
|     bearer_scheme, ) | ||||
|     bearer_scheme, | ||||
| ) | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
|  | @ -29,7 +30,8 @@ router = APIRouter() | |||
| 
 | ||||
| @router.get("/", response_model=List[ChatTitleIdResponse]) | ||||
| async def get_user_chats( | ||||
|         user=Depends(get_current_user), skip: int = 0, limit: int = 50): | ||||
|     user=Depends(get_current_user), skip: int = 0, limit: int = 50 | ||||
| ): | ||||
|     return Chats.get_chat_lists_by_user_id(user.id, skip, limit) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -41,9 +43,8 @@ async def get_user_chats( | |||
| @router.get("/all", response_model=List[ChatResponse]) | ||||
| async def get_all_user_chats(user=Depends(get_current_user)): | ||||
|     return [ | ||||
|         ChatResponse(**{ | ||||
|             **chat.model_dump(), "chat": json.loads(chat.chat) | ||||
|         }) for chat in Chats.get_all_chats_by_user_id(user.id) | ||||
|         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||
|         for chat in Chats.get_all_chats_by_user_id(user.id) | ||||
|     ] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -54,8 +55,14 @@ async def get_all_user_chats(user=Depends(get_current_user)): | |||
| 
 | ||||
| @router.post("/new", response_model=Optional[ChatResponse]) | ||||
| async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): | ||||
|     try: | ||||
|         chat = Chats.insert_new_chat(user.id, form_data) | ||||
|         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
|  | @ -68,12 +75,11 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): | |||
|     chat = Chats.get_chat_by_id_and_user_id(id, user.id) | ||||
| 
 | ||||
|     if chat: | ||||
|         return ChatResponse(**{ | ||||
|             **chat.model_dump(), "chat": json.loads(chat.chat) | ||||
|         }) | ||||
|         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||
|     else: | ||||
|         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                             detail=ERROR_MESSAGES.NOT_FOUND) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
|  | @ -82,17 +88,15 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): | |||
| 
 | ||||
| 
 | ||||
| @router.post("/{id}", response_model=Optional[ChatResponse]) | ||||
| async def update_chat_by_id(id: str, | ||||
|                             form_data: ChatForm, | ||||
|                             user=Depends(get_current_user)): | ||||
| async def update_chat_by_id( | ||||
|     id: str, form_data: ChatForm, user=Depends(get_current_user) | ||||
| ): | ||||
|     chat = Chats.get_chat_by_id_and_user_id(id, user.id) | ||||
|     if chat: | ||||
|         updated_chat = {**json.loads(chat.chat), **form_data.chat} | ||||
| 
 | ||||
|         chat = Chats.update_chat_by_id(id, updated_chat) | ||||
|         return ChatResponse(**{ | ||||
|             **chat.model_dump(), "chat": json.loads(chat.chat) | ||||
|         }) | ||||
|         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|  |  | |||
|  | @ -21,7 +21,7 @@ | |||
| 	import { WEB_UI_VERSION, WEBUI_API_BASE_URL } from '$lib/constants'; | ||||
| 
 | ||||
| 	import { config, models, settings, user, chats } from '$lib/stores'; | ||||
| 	import { splitStream, getGravatarURL } from '$lib/utils'; | ||||
| 	import { splitStream, getGravatarURL, getImportOrigin, convertOpenAIChats } from '$lib/utils'; | ||||
| 
 | ||||
| 	import Advanced from './Settings/Advanced.svelte'; | ||||
| 	import Modal from '../common/Modal.svelte'; | ||||
|  | @ -132,6 +132,13 @@ | |||
| 		reader.onload = (event) => { | ||||
| 			let chats = JSON.parse(event.target.result); | ||||
| 			console.log(chats); | ||||
| 			if (getImportOrigin(chats) == 'openai') { | ||||
| 				try { | ||||
| 					chats = convertOpenAIChats(chats); | ||||
| 				} catch (error) { | ||||
| 					console.log('Unable to import chats:', error); | ||||
| 				} | ||||
| 			} | ||||
| 			importChats(chats); | ||||
| 		}; | ||||
| 
 | ||||
|  |  | |||
|  | @ -192,3 +192,74 @@ export const calculateSHA256 = async (file) => { | |||
| 		throw error; | ||||
| 	} | ||||
| }; | ||||
| 
 | ||||
| export const getImportOrigin = (_chats) => { | ||||
| 	// Check what external service chat imports are from
 | ||||
| 	if ('mapping' in _chats[0]) { | ||||
| 		return 'openai'; | ||||
| 	} | ||||
| 	return 'webui'; | ||||
| }; | ||||
| 
 | ||||
| const convertOpenAIMessages = (convo) => { | ||||
| 	// Parse OpenAI chat messages and create chat dictionary for creating new chats
 | ||||
| 	const mapping = convo['mapping']; | ||||
| 	const messages = []; | ||||
| 	let currentId = ''; | ||||
| 
 | ||||
| 	for (let message_id in mapping) { | ||||
| 		const message = mapping[message_id]; | ||||
| 		currentId = message_id; | ||||
| 		if (message['message'] == null || message['message']['content']['parts'][0] == '') { | ||||
| 			// Skip chat messages with no content
 | ||||
| 			continue; | ||||
| 		} else { | ||||
| 			const new_chat = { | ||||
| 				id: message_id, | ||||
| 				parentId: messages.length > 0 && message['parent'] in mapping ? message['parent'] : null, | ||||
| 				childrenIds: message['children'] || [], | ||||
| 				role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user', | ||||
| 				content: message['message']?.['content']?.['parts']?.[0] || '', | ||||
| 				model: 'gpt-3.5-turbo', | ||||
| 				done: true, | ||||
| 				context: null | ||||
| 			}; | ||||
| 			messages.push(new_chat); | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	let history = {}; | ||||
| 	messages.forEach((obj) => (history[obj.id] = obj)); | ||||
| 
 | ||||
| 	const chat = { | ||||
| 		history: { | ||||
| 			currentId: currentId, | ||||
| 			messages: history // Need to convert this to not a list and instead a json object
 | ||||
| 		}, | ||||
| 		models: ['gpt-3.5-turbo'], | ||||
| 		messages: messages, | ||||
| 		options: {}, | ||||
| 		timestamp: convo['create_time'], | ||||
| 		title: convo['title'] ?? 'New Chat' | ||||
| 	}; | ||||
| 	return chat; | ||||
| }; | ||||
| 
 | ||||
| export const convertOpenAIChats = (_chats) => { | ||||
| 	// Create a list of dictionaries with each conversation from import
 | ||||
| 	const chats = []; | ||||
| 	for (let convo of _chats) { | ||||
| 		const chat = convertOpenAIMessages(convo); | ||||
| 
 | ||||
| 		if (Object.keys(chat.history.messages).length > 0) { | ||||
| 			chats.push({ | ||||
| 				id: convo['id'], | ||||
| 				user_id: '', | ||||
| 				title: convo['title'], | ||||
| 				chat: chat, | ||||
| 				timestamp: convo['timestamp'] | ||||
| 			}); | ||||
| 		} | ||||
| 	} | ||||
| 	return chats; | ||||
| }; | ||||
|  |  | |||
|  | @ -200,6 +200,13 @@ | |||
| 					await chatId.set('local'); | ||||
| 				} | ||||
| 				await tick(); | ||||
| 			} else if (chat.chat["models"] != selectedModels) { | ||||
| 				// If model is not saved in DB, then save selectedmodel when message is sent | ||||
| 
 | ||||
| 				chat = await updateChatById(localStorage.token, $chatId, { | ||||
| 						models: selectedModels | ||||
| 					}); | ||||
| 				await chats.set(await getChatList(localStorage.token)); | ||||
| 			} | ||||
| 			 | ||||
| 			// Reset chat input textarea | ||||
|  | @ -696,7 +703,7 @@ | |||
| 	<div class="min-h-screen w-full flex justify-center"> | ||||
| 		<div class=" py-2.5 flex flex-col justify-between w-full"> | ||||
| 			<div class="max-w-2xl mx-auto w-full px-3 md:px-0 mt-10"> | ||||
| 				<ModelSelector bind:selectedModels disabled={messages.length > 0} /> | ||||
| 				<ModelSelector bind:selectedModels disabled={messages.length > 0 && !selectedModels.includes('')} /> | ||||
| 			</div> | ||||
| 
 | ||||
| 			<div class=" h-full mt-10 mb-32 w-full flex flex-col"> | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek