feat: web rag support

This commit is contained in:
Timothy J. Baek 2024-01-26 22:17:28 -08:00
parent 5e672d9f79
commit 28226a6f97
5 changed files with 131 additions and 33 deletions

View file

@ -37,7 +37,7 @@ from typing import Optional
import uuid
import time
from utils.misc import calculate_sha256
from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES
@ -124,10 +124,15 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
try:
loader = WebBaseLoader(form_data.url)
data = loader.load()
store_data_in_vector_db(data, form_data.collection_name)
collection_name = form_data.collection_name
if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name)
return {
"status": True,
"collection_name": form_data.collection_name,
"collection_name": collection_name,
"filename": form_data.url,
}
except Exception as e:

View file

@ -24,6 +24,16 @@ def calculate_sha256(file):
return sha256.hexdigest()
def calculate_sha256_string(string):
# Create a new SHA-256 hash object
sha256_hash = hashlib.sha256()
# Update the hash object with the bytes of the input string
sha256_hash.update(string.encode("utf-8"))
# Get the hexadecimal representation of the hash
hashed_string = sha256_hash.hexdigest()
return hashed_string
def validate_email_format(email: str) -> bool:
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
return False

View file

@ -6,7 +6,7 @@
import Prompts from './MessageInput/PromptCommands.svelte';
import Suggestions from './MessageInput/Suggestions.svelte';
import { uploadDocToVectorDB } from '$lib/apis/rag';
import { uploadDocToVectorDB, uploadWebToVectorDB } from '$lib/apis/rag';
import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte';
import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
import Documents from './MessageInput/Documents.svelte';
@ -137,6 +137,33 @@
}
};
const uploadWeb = async (url) => {
console.log(url);
const doc = {
type: 'doc',
name: url,
collection_name: '',
upload_status: false,
error: ''
};
try {
files = [...files, doc];
const res = await uploadWebToVectorDB(localStorage.token, '', url);
if (res) {
doc.upload_status = true;
doc.collection_name = res.collection_name;
files = files;
}
} catch (e) {
// Remove the failed doc from the files array
files = files.filter((f) => f.name !== url);
toast.error(e);
}
};
onMount(() => {
const dropZone = document.querySelector('body');
@ -258,6 +285,10 @@
<Documents
bind:this={documentsElement}
bind:prompt
on:url={(e) => {
console.log(e);
uploadWeb(e.detail);
}}
on:select={(e) => {
console.log(e);
files = [

View file

@ -2,7 +2,7 @@
import { createEventDispatcher } from 'svelte';
import { documents } from '$lib/stores';
import { removeFirstHashWord } from '$lib/utils';
import { removeFirstHashWord, isValidHttpUrl } from '$lib/utils';
import { tick } from 'svelte';
export let prompt = '';
@ -37,9 +37,20 @@
chatInputElement?.focus();
await tick();
};
const confirmSelectWeb = async (url) => {
dispatch('url', url);
prompt = removeFirstHashWord(prompt);
const chatInputElement = document.getElementById('chat-textarea');
await tick();
chatInputElement?.focus();
await tick();
};
</script>
{#if filteredDocs.length > 0}
{#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<div class="md:px-2 mb-3 text-left w-full">
<div class="flex w-full rounded-lg border border-gray-100 dark:border-gray-700">
<div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-lg text-center">
@ -55,6 +66,7 @@
: ''}"
type="button"
on:click={() => {
console.log(doc);
confirmSelect(doc);
}}
on:mousemove={() => {
@ -71,6 +83,25 @@
</div>
</button>
{/each}
{#if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
<button
class="px-3 py-1.5 rounded-lg w-full text-left bg-gray-100 selected-command-option-button"
type="button"
on:click={() => {
const url = prompt.split(' ')?.at(0)?.substring(1);
if (isValidHttpUrl(url)) {
confirmSelectWeb(url);
}
}}
>
<div class=" font-medium text-black line-clamp-1">
{prompt.split(' ')?.at(0)?.substring(1)}
</div>
<div class=" text-xs text-gray-600 line-clamp-1">Web</div>
</button>
{/if}
</div>
</div>
</div>

View file

@ -212,8 +212,12 @@ const convertOpenAIMessages = (convo) => {
const message = mapping[message_id];
currentId = message_id;
try {
if (messages.length == 0 && (message['message'] == null ||
(message['message']['content']['parts']?.[0] == '' && message['message']['content']['text'] == null))) {
if (
messages.length == 0 &&
(message['message'] == null ||
(message['message']['content']['parts']?.[0] == '' &&
message['message']['content']['text'] == null))
) {
// Skip chat messages with no content
continue;
} else {
@ -222,7 +226,10 @@ const convertOpenAIMessages = (convo) => {
parentId: lastId,
childrenIds: message['children'] || [],
role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user',
content: message['message']?.['content']?.['parts']?.[0] || message['message']?.['content']?.['text'] || '',
content:
message['message']?.['content']?.['parts']?.[0] ||
message['message']?.['content']?.['text'] ||
'',
model: 'gpt-3.5-turbo',
done: true,
context: null
@ -231,7 +238,7 @@ const convertOpenAIMessages = (convo) => {
lastId = currentId;
}
} catch (error) {
console.log("Error with", message, "\nError:", error);
console.log('Error with', message, '\nError:', error);
}
}
@ -256,31 +263,31 @@ const validateChat = (chat) => {
// Because ChatGPT sometimes has features we can't use like DALL-E or migh have corrupted messages, need to validate
const messages = chat.messages;
// Check if messages array is empty
if (messages.length === 0) {
return false;
}
// Check if messages array is empty
if (messages.length === 0) {
return false;
}
// Last message's children should be an empty array
const lastMessage = messages[messages.length - 1];
if (lastMessage.childrenIds.length !== 0) {
return false;
}
// Last message's children should be an empty array
const lastMessage = messages[messages.length - 1];
if (lastMessage.childrenIds.length !== 0) {
return false;
}
// First message's parent should be null
const firstMessage = messages[0];
if (firstMessage.parentId !== null) {
return false;
}
// First message's parent should be null
const firstMessage = messages[0];
if (firstMessage.parentId !== null) {
return false;
}
// Every message's content should be a string
for (let message of messages) {
if (typeof message.content !== 'string') {
return false;
}
}
// Every message's content should be a string
for (let message of messages) {
if (typeof message.content !== 'string') {
return false;
}
}
return true;
return true;
};
export const convertOpenAIChats = (_chats) => {
@ -298,8 +305,22 @@ export const convertOpenAIChats = (_chats) => {
chat: chat,
timestamp: convo['timestamp']
});
} else { failed ++}
} else {
failed++;
}
}
console.log(failed, "Conversations could not be imported");
console.log(failed, 'Conversations could not be imported');
return chats;
};
export const isValidHttpUrl = (string) => {
let url;
try {
url = new URL(string);
} catch (_) {
return false;
}
return url.protocol === 'http:' || url.protocol === 'https:';
};