diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 6da870ea..85bc995a 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -37,7 +37,7 @@ from typing import Optional import uuid import time -from utils.misc import calculate_sha256 +from utils.misc import calculate_sha256, calculate_sha256_string from utils.utils import get_current_user from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from constants import ERROR_MESSAGES @@ -124,10 +124,15 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): try: loader = WebBaseLoader(form_data.url) data = loader.load() - store_data_in_vector_db(data, form_data.collection_name) + + collection_name = form_data.collection_name + if collection_name == "": + collection_name = calculate_sha256_string(form_data.url)[:63] + + store_data_in_vector_db(data, collection_name) return { "status": True, - "collection_name": form_data.collection_name, + "collection_name": collection_name, "filename": form_data.url, } except Exception as e: diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5635c57a..385a2c41 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -24,6 +24,16 @@ def calculate_sha256(file): return sha256.hexdigest() +def calculate_sha256_string(string): + # Create a new SHA-256 hash object + sha256_hash = hashlib.sha256() + # Update the hash object with the bytes of the input string + sha256_hash.update(string.encode("utf-8")) + # Get the hexadecimal representation of the hash + hashed_string = sha256_hash.hexdigest() + return hashed_string + + def validate_email_format(email: str) -> bool: if not re.match(r"[^@]+@[^@]+\.[^@]+", email): return False diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 604a5689..96d7d2e1 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -6,7 +6,7 @@ import Prompts from './MessageInput/PromptCommands.svelte'; import Suggestions from './MessageInput/Suggestions.svelte'; - import { uploadDocToVectorDB } from '$lib/apis/rag'; + import { uploadDocToVectorDB, uploadWebToVectorDB } from '$lib/apis/rag'; import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte'; import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants'; import Documents from './MessageInput/Documents.svelte'; @@ -137,6 +137,33 @@ } }; + const uploadWeb = async (url) => { + console.log(url); + + const doc = { + type: 'doc', + name: url, + collection_name: '', + upload_status: false, + error: '' + }; + + try { + files = [...files, doc]; + const res = await uploadWebToVectorDB(localStorage.token, '', url); + + if (res) { + doc.upload_status = true; + doc.collection_name = res.collection_name; + files = files; + } + } catch (e) { + // Remove the failed doc from the files array + files = files.filter((f) => f.name !== url); + toast.error(e); + } + }; + onMount(() => { const dropZone = document.querySelector('body'); @@ -258,6 +285,10 @@ { + console.log(e); + uploadWeb(e.detail); + }} on:select={(e) => { console.log(e); files = [ diff --git a/src/lib/components/chat/MessageInput/Documents.svelte b/src/lib/components/chat/MessageInput/Documents.svelte index bcfb1916..5f252b3d 100644 --- a/src/lib/components/chat/MessageInput/Documents.svelte +++ b/src/lib/components/chat/MessageInput/Documents.svelte @@ -2,8 +2,9 @@ import { createEventDispatcher } from 'svelte'; import { documents } from '$lib/stores'; - import { removeFirstHashWord } from '$lib/utils'; + import { removeFirstHashWord, isValidHttpUrl } from '$lib/utils'; import { tick } from 'svelte'; + import toast from 'svelte-french-toast'; export let prompt = ''; @@ -37,9 +38,20 @@ chatInputElement?.focus(); await tick(); }; + + const confirmSelectWeb = async (url) => { + dispatch('url', url); + + prompt = removeFirstHashWord(prompt); + const chatInputElement = document.getElementById('chat-textarea'); + + await tick(); + chatInputElement?.focus(); + await tick(); + }; -{#if filteredDocs.length > 0} +{#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
@@ -55,6 +67,7 @@ : ''}" type="button" on:click={() => { + console.log(doc); confirmSelect(doc); }} on:mousemove={() => { @@ -71,6 +84,29 @@
{/each} + + {#if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')} + + {/if}
diff --git a/src/lib/utils/index.ts b/src/lib/utils/index.ts index b12bdedd..16bf1cd5 100644 --- a/src/lib/utils/index.ts +++ b/src/lib/utils/index.ts @@ -212,8 +212,12 @@ const convertOpenAIMessages = (convo) => { const message = mapping[message_id]; currentId = message_id; try { - if (messages.length == 0 && (message['message'] == null || - (message['message']['content']['parts']?.[0] == '' && message['message']['content']['text'] == null))) { + if ( + messages.length == 0 && + (message['message'] == null || + (message['message']['content']['parts']?.[0] == '' && + message['message']['content']['text'] == null)) + ) { // Skip chat messages with no content continue; } else { @@ -222,7 +226,10 @@ const convertOpenAIMessages = (convo) => { parentId: lastId, childrenIds: message['children'] || [], role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user', - content: message['message']?.['content']?.['parts']?.[0] || message['message']?.['content']?.['text'] || '', + content: + message['message']?.['content']?.['parts']?.[0] || + message['message']?.['content']?.['text'] || + '', model: 'gpt-3.5-turbo', done: true, context: null @@ -231,7 +238,7 @@ const convertOpenAIMessages = (convo) => { lastId = currentId; } } catch (error) { - console.log("Error with", message, "\nError:", error); + console.log('Error with', message, '\nError:', error); } } @@ -256,31 +263,31 @@ const validateChat = (chat) => { // Because ChatGPT sometimes has features we can't use like DALL-E or migh have corrupted messages, need to validate const messages = chat.messages; - // Check if messages array is empty - if (messages.length === 0) { - return false; - } + // Check if messages array is empty + if (messages.length === 0) { + return false; + } - // Last message's children should be an empty array - const lastMessage = messages[messages.length - 1]; - if (lastMessage.childrenIds.length !== 0) { - return false; - } + // Last message's children should be an empty array + const lastMessage = messages[messages.length - 1]; + if (lastMessage.childrenIds.length !== 0) { + return false; + } - // First message's parent should be null - const firstMessage = messages[0]; - if (firstMessage.parentId !== null) { - return false; - } + // First message's parent should be null + const firstMessage = messages[0]; + if (firstMessage.parentId !== null) { + return false; + } - // Every message's content should be a string - for (let message of messages) { - if (typeof message.content !== 'string') { - return false; - } - } + // Every message's content should be a string + for (let message of messages) { + if (typeof message.content !== 'string') { + return false; + } + } - return true; + return true; }; export const convertOpenAIChats = (_chats) => { @@ -298,8 +305,22 @@ export const convertOpenAIChats = (_chats) => { chat: chat, timestamp: convo['timestamp'] }); - } else { failed ++} + } else { + failed++; + } } - console.log(failed, "Conversations could not be imported"); + console.log(failed, 'Conversations could not be imported'); return chats; }; + +export const isValidHttpUrl = (string) => { + let url; + + try { + url = new URL(string); + } catch (_) { + return false; + } + + return url.protocol === 'http:' || url.protocol === 'https:'; +};