diff --git a/README.md b/README.md index dfa7c1a5..3a14b00a 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ Also check our sibling project, [OllamaHub](https://ollamahub.com/), where you c - 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data. +- 🏷️ **Conversation Tagging**: Effortlessly categorize and locate specific chats for quick reference and streamlined data collection. + - 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI. - ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face. diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 1eeae85f..a94dc37f 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool import requests import json +import uuid from pydantic import BaseModel from apps.web.models.users import Users @@ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL # TARGET_SERVER_URL = OLLAMA_API_BASE_URL +REQUEST_POOL = [] + + @app.get("/url") async def get_ollama_api_url(user=Depends(get_current_user)): if user and user.role == "admin": @@ -49,6 +53,16 @@ async def update_ollama_api_url( raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +@app.get("/cancel/{request_id}") +async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)): + if user: + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + return True + else: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_current_user)): target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" @@ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): def get_request(): nonlocal r + + request_id = str(uuid.uuid4()) try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if path in ["chat"]: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + REQUEST_POOL.remove(request_id) + r = requests.request( method=request.method, url=target_url, @@ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): r.raise_for_status() + # r.close() + return StreamingResponse( - r.iter_content(chunk_size=8192), + stream_content(), status_code=r.status_code, headers=dict(r.headers), ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index ef9330c5..780475ad 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -37,19 +37,16 @@ async def get_openai_url(user=Depends(get_current_user)): if user and user.role == "admin": return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} else: - raise HTTPException(status_code=401, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, - user=Depends(get_current_user)): +async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): if user and user.role == "admin": app.state.OPENAI_API_BASE_URL = form_data.url return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} else: - raise HTTPException(status_code=401, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.get("/key") @@ -57,19 +54,16 @@ async def get_openai_key(user=Depends(get_current_user)): if user and user.role == "admin": return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} else: - raise HTTPException(status_code=401, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.post("/key/update") -async def update_openai_key(form_data: KeyUpdateForm, - user=Depends(get_current_user)): +async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): if user and user.role == "admin": app.state.OPENAI_API_KEY = form_data.key return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} else: - raise HTTPException(status_code=401, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @@ -78,15 +72,29 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): print(target_url, app.state.OPENAI_API_KEY) if user.role not in ["user", "admin"]: - raise HTTPException(status_code=401, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) if app.state.OPENAI_API_KEY == "": - raise HTTPException(status_code=401, - detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) body = await request.body() - # headers = dict(request.headers) - # print(headers) + + # TODO: Remove below after gpt-4-vision fix from Open AI + # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + try: + body = body.decode("utf-8") + body = json.loads(body) + + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if body.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in body: + body["max_tokens"] = 4000 + print("Modified body_dict:", body) + + # Convert the modified body back to JSON + body = json.dumps(body) + except json.JSONDecodeError as e: + print("Error loading request body into a dictionary:", e) headers = {} headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" @@ -125,8 +133,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): if "openai" in app.state.OPENAI_API_BASE_URL and path == "models": response_data["data"] = list( - filter(lambda model: "gpt" in model["id"], - response_data["data"])) + filter(lambda model: "gpt" in model["id"], response_data["data"]) + ) return response_data except Exception as e: diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index bc4659de..4895740d 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def __init__(self, db): self.db = db db.create_tables([Chat]) - def insert_new_chat(self, user_id: str, - form_data: ChatForm) -> Optional[ChatModel]: + def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: id = str(uuid.uuid4()) chat = ChatModel( **{ "id": id, "user_id": user_id, - "title": form_data.chat["title"] if "title" in - form_data.chat else "New Chat", + "title": form_data.chat["title"] + if "title" in form_data.chat + else "New Chat", "chat": json.dumps(form_data.chat), "timestamp": int(time.time()), - }) + } + ) result = Chat.create(**chat.model_dump()) return chat if result else None @@ -109,25 +109,37 @@ class ChatTable: except: return None - def get_chat_lists_by_user_id(self, - user_id: str, - skip: int = 0, - limit: int = 50) -> List[ChatModel]: + def get_chat_lists_by_user_id( + self, user_id: str, skip: int = 0, limit: int = 50 + ) -> List[ChatModel]: return [ - ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( - Chat.user_id == user_id).order_by(Chat.timestamp.desc()) + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.timestamp.desc()) # .limit(limit) # .offset(skip) ] - def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + def get_chat_lists_by_chat_ids( + self, chat_ids: List[str], skip: int = 0, limit: int = 50 + ) -> List[ChatModel]: return [ - ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( - Chat.user_id == user_id).order_by(Chat.timestamp.desc()) + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.id.in_(chat_ids)) + .order_by(Chat.timestamp.desc()) ] - def get_chat_by_id_and_user_id(self, id: str, - user_id: str) -> Optional[ChatModel]: + def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.timestamp.desc()) + ] + + def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: chat = Chat.get(Chat.id == id, Chat.user_id == user_id) return ChatModel(**model_to_dict(chat)) @@ -142,8 +154,7 @@ class ChatTable: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - query = Chat.delete().where((Chat.id == id) - & (Chat.user_id == user_id)) + query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query.execute() # Remove the rows, return number of rows removed. return True diff --git a/backend/apps/web/models/tags.py b/backend/apps/web/models/tags.py new file mode 100644 index 00000000..c14658cf --- /dev/null +++ b/backend/apps/web/models/tags.py @@ -0,0 +1,206 @@ +from pydantic import BaseModel +from typing import List, Union, Optional +from peewee import * +from playhouse.shortcuts import model_to_dict + +import json +import uuid +import time + +from apps.web.internal.db import DB + +#################### +# Tag DB Schema +#################### + + +class Tag(Model): + id = CharField(unique=True) + name = CharField() + user_id = CharField() + data = TextField(null=True) + + class Meta: + database = DB + + +class ChatIdTag(Model): + id = CharField(unique=True) + tag_name = CharField() + chat_id = CharField() + user_id = CharField() + timestamp = DateField() + + class Meta: + database = DB + + +class TagModel(BaseModel): + id: str + name: str + user_id: str + data: Optional[str] = None + + +class ChatIdTagModel(BaseModel): + id: str + tag_name: str + chat_id: str + user_id: str + timestamp: int + + +#################### +# Forms +#################### + + +class ChatIdTagForm(BaseModel): + tag_name: str + chat_id: str + + +class TagChatIdsResponse(BaseModel): + chat_ids: List[str] + + +class ChatTagsResponse(BaseModel): + tags: List[str] + + +class TagTable: + def __init__(self, db): + self.db = db + db.create_tables([Tag, ChatIdTag]) + + def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: + id = str(uuid.uuid4()) + tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) + try: + result = Tag.create(**tag.model_dump()) + if result: + return tag + else: + return None + except Exception as e: + return None + + def get_tag_by_name_and_user_id( + self, name: str, user_id: str + ) -> Optional[TagModel]: + try: + tag = Tag.get(Tag.name == name, Tag.user_id == user_id) + return TagModel(**model_to_dict(tag)) + except Exception as e: + return None + + def add_tag_to_chat( + self, user_id: str, form_data: ChatIdTagForm + ) -> Optional[ChatIdTagModel]: + tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) + if tag == None: + tag = self.insert_new_tag(form_data.tag_name, user_id) + + id = str(uuid.uuid4()) + chatIdTag = ChatIdTagModel( + **{ + "id": id, + "user_id": user_id, + "chat_id": form_data.chat_id, + "tag_name": tag.name, + "timestamp": int(time.time()), + } + ) + try: + result = ChatIdTag.create(**chatIdTag.model_dump()) + if result: + return chatIdTag + else: + return None + except: + return None + + def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: + tag_names = [ + ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name + for chat_id_tag in ChatIdTag.select() + .where(ChatIdTag.user_id == user_id) + .order_by(ChatIdTag.timestamp.desc()) + ] + + return [ + TagModel(**model_to_dict(tag)) + for tag in Tag.select().where(Tag.name.in_(tag_names)) + ] + + def get_tags_by_chat_id_and_user_id( + self, chat_id: str, user_id: str + ) -> List[TagModel]: + tag_names = [ + ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name + for chat_id_tag in ChatIdTag.select() + .where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)) + .order_by(ChatIdTag.timestamp.desc()) + ] + + return [ + TagModel(**model_to_dict(tag)) + for tag in Tag.select().where(Tag.name.in_(tag_names)) + ] + + def get_chat_ids_by_tag_name_and_user_id( + self, tag_name: str, user_id: str + ) -> Optional[ChatIdTagModel]: + return [ + ChatIdTagModel(**model_to_dict(chat_id_tag)) + for chat_id_tag in ChatIdTag.select() + .where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name)) + .order_by(ChatIdTag.timestamp.desc()) + ] + + def count_chat_ids_by_tag_name_and_user_id( + self, tag_name: str, user_id: str + ) -> int: + return ( + ChatIdTag.select() + .where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)) + .count() + ) + + def delete_tag_by_tag_name_and_chat_id_and_user_id( + self, tag_name: str, chat_id: str, user_id: str + ) -> bool: + try: + query = ChatIdTag.delete().where( + (ChatIdTag.tag_name == tag_name) + & (ChatIdTag.chat_id == chat_id) + & (ChatIdTag.user_id == user_id) + ) + res = query.execute() # Remove the rows, return number of rows removed. + print(res) + + tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) + if tag_count == 0: + # Remove tag item from Tag col as well + query = Tag.delete().where( + (Tag.name == tag_name) & (Tag.user_id == user_id) + ) + query.execute() # Remove the rows, return number of rows removed. + + return True + except Exception as e: + print("delete_tag", e) + return False + + def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: + tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) + + for tag in tags: + self.delete_tag_by_tag_name_and_chat_id_and_user_id( + tag.tag_name, chat_id, user_id + ) + + return True + + +Tags = TagTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index f245601d..a0772223 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -91,42 +91,40 @@ async def signin(form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): - if request.app.state.ENABLE_SIGNUP: - if validate_email_format(form_data.email.lower()): - if not Users.get_user_by_email(form_data.email.lower()): - try: - role = "admin" if Users.get_num_users() == 0 else "pending" - hashed = get_password_hash(form_data.password) - user = Auths.insert_new_auth(form_data.email.lower(), - hashed, form_data.name, role) - - if user: - token = create_token(data={"email": user.email}) - # response.set_cookie(key='token', value=token, httponly=True) - - return { - "token": token, - "token_type": "Bearer", - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": user.profile_image_url, - } - else: - raise HTTPException( - 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) - except Exception as err: - raise HTTPException(500, - detail=ERROR_MESSAGES.DEFAULT(err)) - else: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - else: - raise HTTPException(400, - detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) - else: + if not request.app.state.ENABLE_SIGNUP: raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + if not validate_email_format(form_data.email.lower()): + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) + + if Users.get_user_by_email(form_data.email.lower()): + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) + + try: + role = "admin" if Users.get_num_users() == 0 else "pending" + hashed = get_password_hash(form_data.password) + user = Auths.insert_new_auth(form_data.email.lower(), + hashed, form_data.name, role) + if user: + token = create_token(data={"email": user.email}) + # response.set_cookie(key='token', value=token, httponly=True) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException( + 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) + except Exception as err: + raise HTTPException(500, + detail=ERROR_MESSAGES.DEFAULT(err)) ############################ # ToggleSignUp diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index e97e1473..29214229 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -16,6 +16,15 @@ from apps.web.models.chats import ( Chats, ) + +from apps.web.models.tags import ( + TagModel, + ChatIdTagModel, + ChatIdTagForm, + ChatTagsResponse, + Tags, +) + from utils.utils import ( bearer_scheme, ) @@ -65,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): ) +############################ +# GetAllTags +############################ + + +@router.get("/tags/all", response_model=List[TagModel]) +async def get_all_tags(user=Depends(get_current_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# GetChatsByTags +############################ + + +@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse]) +async def get_user_chats_by_tag_name( + tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50 +): + chat_ids = [ + chat_id_tag.chat_id + for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id) + ] + + print(chat_ids) + + return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit) + + ############################ # GetChatById ############################ @@ -115,6 +160,88 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)): return result +############################ +# GetChatTagsById +############################ + + +@router.get("/{id}/tags", response_model=List[TagModel]) +async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): + tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) + + if tags != None: + return tags + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + +############################ +# AddChatTagById +############################ + + +@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) +async def add_chat_tag_by_id( + id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) +): + tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) + + if form_data.tag_name not in tags: + tag = Tags.add_tag_to_chat(user.id, form_data) + + if tag: + return tag + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# DeleteChatTagById +############################ + + +@router.delete("/{id}/tags", response_model=Optional[bool]) +async def delete_chat_tag_by_id( + id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) +): + result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( + form_data.tag_name, id, user.id + ) + + if result: + return result + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + +############################ +# DeleteAllChatTagsById +############################ + + +@router.delete("/{id}/tags/all", response_model=Optional[bool]) +async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): + result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) + + if result: + return result + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + ############################ # DeleteAllChats ############################ diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 0eddf5b4..7c515f11 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -93,6 +93,68 @@ export const getAllChats = async (token: string) => { return res; }; +export const getAllChatTags = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getChatListByTagName = async (token: string = '', tagName: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getChatById = async (token: string, id: string) => { let error = null; @@ -192,6 +254,141 @@ export const deleteChatById = async (token: string, id: string) => { return res; }; +export const getTagsById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/tags`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const addTagById = async (token: string, id: string, tagName: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/tags`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + tag_name: tagName, + chat_id: id + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteTagById = async (token: string, id: string, tagName: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/tags`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + tag_name: tagName, + chat_id: id + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; +export const deleteTagsById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/tags/all`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteAllChats = async (token: string) => { let error = null; diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index e863e51e..62501966 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -206,9 +206,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa }; export const generateChatCompletion = async (token: string = '', body: object) => { + let controller = new AbortController(); let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { + signal: controller.signal, method: 'POST', headers: { 'Content-Type': 'text/event-stream', @@ -224,6 +226,27 @@ export const generateChatCompletion = async (token: string = '', body: object) = throw error; } + return [res, controller]; +}; + +export const cancelChatCompletion = async (token: string = '', requestId: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, { + method: 'GET', + headers: { + 'Content-Type': 'text/event-stream', + Authorization: `Bearer ${token}` + } + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + return res; }; diff --git a/src/lib/components/chat/Messages/CodeBlock.svelte b/src/lib/components/chat/Messages/CodeBlock.svelte new file mode 100644 index 00000000..c5290547 --- /dev/null +++ b/src/lib/components/chat/Messages/CodeBlock.svelte @@ -0,0 +1,38 @@ + + +{#if code} +
{@html highlightedCode || code}
+