forked from open-webui/open-webui
feat: add backend functions for sharing chats
This commit is contained in:
parent
a363c1f2f1
commit
94976e5ed3
2 changed files with 112 additions and 0 deletions
|
@ -20,6 +20,7 @@ class Chat(Model):
|
||||||
title = CharField()
|
title = CharField()
|
||||||
chat = TextField() # Save Chat JSON as Text
|
chat = TextField() # Save Chat JSON as Text
|
||||||
timestamp = DateField()
|
timestamp = DateField()
|
||||||
|
share_id = CharField(null=True, unique=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
database = DB
|
database = DB
|
||||||
|
@ -31,6 +32,7 @@ class ChatModel(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
chat: str
|
chat: str
|
||||||
timestamp: int # timestamp in epoch
|
timestamp: int # timestamp in epoch
|
||||||
|
share_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
|
@ -52,6 +54,7 @@ class ChatResponse(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
chat: dict
|
chat: dict
|
||||||
timestamp: int # timestamp in epoch
|
timestamp: int # timestamp in epoch
|
||||||
|
share_id: Optional[str] = None # id of the chat to be shared
|
||||||
|
|
||||||
|
|
||||||
class ChatTitleIdResponse(BaseModel):
|
class ChatTitleIdResponse(BaseModel):
|
||||||
|
@ -95,6 +98,44 @@ class ChatTable:
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def insert_shared_chat(self, chat_id: str) -> Optional[ChatModel]:
|
||||||
|
# Get the existing chat to share
|
||||||
|
chat = Chat.get(Chat.id == chat_id)
|
||||||
|
# Check if the chat is already shared
|
||||||
|
if chat.share_id:
|
||||||
|
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
|
||||||
|
# Create a new chat with the same data, but with a new ID
|
||||||
|
shared_chat = ChatModel(
|
||||||
|
**{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"user_id": "shared",
|
||||||
|
"title": chat.title,
|
||||||
|
"chat": chat.chat,
|
||||||
|
"timestamp": int(time.time()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
shared_result = Chat.create(**shared_chat.model_dump())
|
||||||
|
# Update the original chat with the share_id
|
||||||
|
result = (
|
||||||
|
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
return shared_chat if (shared_result and result) else None
|
||||||
|
|
||||||
|
def update_chat_share_id_by_id(
|
||||||
|
self, od: str, share_id: Optional[str]
|
||||||
|
) -> Optional[ChatModel]:
|
||||||
|
try:
|
||||||
|
query = Chat.update(
|
||||||
|
share_id=share_id,
|
||||||
|
).where(Chat.id == id)
|
||||||
|
query.execute()
|
||||||
|
|
||||||
|
chat = Chat.get(Chat.id == id)
|
||||||
|
return ChatModel(**model_to_dict(chat))
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_chat_lists_by_user_id(
|
def get_chat_lists_by_user_id(
|
||||||
self, user_id: str, skip: int = 0, limit: int = 50
|
self, user_id: str, skip: int = 0, limit: int = 50
|
||||||
) -> List[ChatModel]:
|
) -> List[ChatModel]:
|
||||||
|
|
|
@ -189,6 +189,77 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# ShareChatById
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/share", response_model=Optional[ChatResponse])
|
||||||
|
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
|
if chat:
|
||||||
|
if chat.share_id:
|
||||||
|
shared_chat = Chats.get_chat_by_id_and_user_id(chat.share_id, "shared")
|
||||||
|
return ChatResponse(
|
||||||
|
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
|
||||||
|
)
|
||||||
|
|
||||||
|
shared_chat = Chats.insert_shared_chat(chat.id)
|
||||||
|
if not shared_chat:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# DeletedSharedChatById
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{id}/share", response_model=Optional[bool])
|
||||||
|
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
|
if chat:
|
||||||
|
if not chat.share_id:
|
||||||
|
return False
|
||||||
|
result = Chats.delete_chat_by_id_and_user_id(chat.share_id, "shared")
|
||||||
|
update_result = Chats.update_chat_share_id_by_id(chat.id, None)
|
||||||
|
|
||||||
|
return result and update_result
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# GetSharedChatById
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/share/{id}", response_model=Optional[ChatResponse])
|
||||||
|
async def get_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||||
|
chat = Chats.get_chat_by_id_and_user_id(id, "shared")
|
||||||
|
|
||||||
|
if chat:
|
||||||
|
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetChatTagsById
|
# GetChatTagsById
|
||||||
############################
|
############################
|
||||||
|
|
Loading…
Reference in a new issue