forked from open-webui/open-webui
Merge pull request #585 from ollama-webui/web-rag
feat: web rag support
This commit is contained in:
commit
0be2803fb9
5 changed files with 136 additions and 33 deletions
|
@ -37,7 +37,7 @@ from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from utils.misc import calculate_sha256
|
from utils.misc import calculate_sha256, calculate_sha256_string
|
||||||
from utils.utils import get_current_user
|
from utils.utils import get_current_user
|
||||||
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
@ -124,10 +124,15 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
||||||
try:
|
try:
|
||||||
loader = WebBaseLoader(form_data.url)
|
loader = WebBaseLoader(form_data.url)
|
||||||
data = loader.load()
|
data = loader.load()
|
||||||
store_data_in_vector_db(data, form_data.collection_name)
|
|
||||||
|
collection_name = form_data.collection_name
|
||||||
|
if collection_name == "":
|
||||||
|
collection_name = calculate_sha256_string(form_data.url)[:63]
|
||||||
|
|
||||||
|
store_data_in_vector_db(data, collection_name)
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"collection_name": form_data.collection_name,
|
"collection_name": collection_name,
|
||||||
"filename": form_data.url,
|
"filename": form_data.url,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -24,6 +24,16 @@ def calculate_sha256(file):
|
||||||
return sha256.hexdigest()
|
return sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_sha256_string(string):
|
||||||
|
# Create a new SHA-256 hash object
|
||||||
|
sha256_hash = hashlib.sha256()
|
||||||
|
# Update the hash object with the bytes of the input string
|
||||||
|
sha256_hash.update(string.encode("utf-8"))
|
||||||
|
# Get the hexadecimal representation of the hash
|
||||||
|
hashed_string = sha256_hash.hexdigest()
|
||||||
|
return hashed_string
|
||||||
|
|
||||||
|
|
||||||
def validate_email_format(email: str) -> bool:
|
def validate_email_format(email: str) -> bool:
|
||||||
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
|
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import Prompts from './MessageInput/PromptCommands.svelte';
|
import Prompts from './MessageInput/PromptCommands.svelte';
|
||||||
import Suggestions from './MessageInput/Suggestions.svelte';
|
import Suggestions from './MessageInput/Suggestions.svelte';
|
||||||
import { uploadDocToVectorDB } from '$lib/apis/rag';
|
import { uploadDocToVectorDB, uploadWebToVectorDB } from '$lib/apis/rag';
|
||||||
import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte';
|
import AddFilesPlaceholder from '../AddFilesPlaceholder.svelte';
|
||||||
import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
|
import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
|
||||||
import Documents from './MessageInput/Documents.svelte';
|
import Documents from './MessageInput/Documents.svelte';
|
||||||
|
@ -137,6 +137,33 @@
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const uploadWeb = async (url) => {
|
||||||
|
console.log(url);
|
||||||
|
|
||||||
|
const doc = {
|
||||||
|
type: 'doc',
|
||||||
|
name: url,
|
||||||
|
collection_name: '',
|
||||||
|
upload_status: false,
|
||||||
|
error: ''
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
files = [...files, doc];
|
||||||
|
const res = await uploadWebToVectorDB(localStorage.token, '', url);
|
||||||
|
|
||||||
|
if (res) {
|
||||||
|
doc.upload_status = true;
|
||||||
|
doc.collection_name = res.collection_name;
|
||||||
|
files = files;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
// Remove the failed doc from the files array
|
||||||
|
files = files.filter((f) => f.name !== url);
|
||||||
|
toast.error(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
onMount(() => {
|
onMount(() => {
|
||||||
const dropZone = document.querySelector('body');
|
const dropZone = document.querySelector('body');
|
||||||
|
|
||||||
|
@ -258,6 +285,10 @@
|
||||||
<Documents
|
<Documents
|
||||||
bind:this={documentsElement}
|
bind:this={documentsElement}
|
||||||
bind:prompt
|
bind:prompt
|
||||||
|
on:url={(e) => {
|
||||||
|
console.log(e);
|
||||||
|
uploadWeb(e.detail);
|
||||||
|
}}
|
||||||
on:select={(e) => {
|
on:select={(e) => {
|
||||||
console.log(e);
|
console.log(e);
|
||||||
files = [
|
files = [
|
||||||
|
|
|
@ -2,8 +2,9 @@
|
||||||
import { createEventDispatcher } from 'svelte';
|
import { createEventDispatcher } from 'svelte';
|
||||||
|
|
||||||
import { documents } from '$lib/stores';
|
import { documents } from '$lib/stores';
|
||||||
import { removeFirstHashWord } from '$lib/utils';
|
import { removeFirstHashWord, isValidHttpUrl } from '$lib/utils';
|
||||||
import { tick } from 'svelte';
|
import { tick } from 'svelte';
|
||||||
|
import toast from 'svelte-french-toast';
|
||||||
|
|
||||||
export let prompt = '';
|
export let prompt = '';
|
||||||
|
|
||||||
|
@ -37,9 +38,20 @@
|
||||||
chatInputElement?.focus();
|
chatInputElement?.focus();
|
||||||
await tick();
|
await tick();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const confirmSelectWeb = async (url) => {
|
||||||
|
dispatch('url', url);
|
||||||
|
|
||||||
|
prompt = removeFirstHashWord(prompt);
|
||||||
|
const chatInputElement = document.getElementById('chat-textarea');
|
||||||
|
|
||||||
|
await tick();
|
||||||
|
chatInputElement?.focus();
|
||||||
|
await tick();
|
||||||
|
};
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
{#if filteredDocs.length > 0}
|
{#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
|
||||||
<div class="md:px-2 mb-3 text-left w-full">
|
<div class="md:px-2 mb-3 text-left w-full">
|
||||||
<div class="flex w-full rounded-lg border border-gray-100 dark:border-gray-700">
|
<div class="flex w-full rounded-lg border border-gray-100 dark:border-gray-700">
|
||||||
<div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-lg text-center">
|
<div class=" bg-gray-100 dark:bg-gray-700 w-10 rounded-l-lg text-center">
|
||||||
|
@ -55,6 +67,7 @@
|
||||||
: ''}"
|
: ''}"
|
||||||
type="button"
|
type="button"
|
||||||
on:click={() => {
|
on:click={() => {
|
||||||
|
console.log(doc);
|
||||||
confirmSelect(doc);
|
confirmSelect(doc);
|
||||||
}}
|
}}
|
||||||
on:mousemove={() => {
|
on:mousemove={() => {
|
||||||
|
@ -71,6 +84,29 @@
|
||||||
</div>
|
</div>
|
||||||
</button>
|
</button>
|
||||||
{/each}
|
{/each}
|
||||||
|
|
||||||
|
{#if prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
|
||||||
|
<button
|
||||||
|
class="px-3 py-1.5 rounded-lg w-full text-left bg-gray-100 selected-command-option-button"
|
||||||
|
type="button"
|
||||||
|
on:click={() => {
|
||||||
|
const url = prompt.split(' ')?.at(0)?.substring(1);
|
||||||
|
if (isValidHttpUrl(url)) {
|
||||||
|
confirmSelectWeb(url);
|
||||||
|
} else {
|
||||||
|
toast.error(
|
||||||
|
'Oops! Looks like the URL is invalid. Please double-check and try again.'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div class=" font-medium text-black line-clamp-1">
|
||||||
|
{prompt.split(' ')?.at(0)?.substring(1)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class=" text-xs text-gray-600 line-clamp-1">Web</div>
|
||||||
|
</button>
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -212,8 +212,12 @@ const convertOpenAIMessages = (convo) => {
|
||||||
const message = mapping[message_id];
|
const message = mapping[message_id];
|
||||||
currentId = message_id;
|
currentId = message_id;
|
||||||
try {
|
try {
|
||||||
if (messages.length == 0 && (message['message'] == null ||
|
if (
|
||||||
(message['message']['content']['parts']?.[0] == '' && message['message']['content']['text'] == null))) {
|
messages.length == 0 &&
|
||||||
|
(message['message'] == null ||
|
||||||
|
(message['message']['content']['parts']?.[0] == '' &&
|
||||||
|
message['message']['content']['text'] == null))
|
||||||
|
) {
|
||||||
// Skip chat messages with no content
|
// Skip chat messages with no content
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
|
@ -222,7 +226,10 @@ const convertOpenAIMessages = (convo) => {
|
||||||
parentId: lastId,
|
parentId: lastId,
|
||||||
childrenIds: message['children'] || [],
|
childrenIds: message['children'] || [],
|
||||||
role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user',
|
role: message['message']?.['author']?.['role'] !== 'user' ? 'assistant' : 'user',
|
||||||
content: message['message']?.['content']?.['parts']?.[0] || message['message']?.['content']?.['text'] || '',
|
content:
|
||||||
|
message['message']?.['content']?.['parts']?.[0] ||
|
||||||
|
message['message']?.['content']?.['text'] ||
|
||||||
|
'',
|
||||||
model: 'gpt-3.5-turbo',
|
model: 'gpt-3.5-turbo',
|
||||||
done: true,
|
done: true,
|
||||||
context: null
|
context: null
|
||||||
|
@ -231,7 +238,7 @@ const convertOpenAIMessages = (convo) => {
|
||||||
lastId = currentId;
|
lastId = currentId;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log("Error with", message, "\nError:", error);
|
console.log('Error with', message, '\nError:', error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,8 +305,22 @@ export const convertOpenAIChats = (_chats) => {
|
||||||
chat: chat,
|
chat: chat,
|
||||||
timestamp: convo['timestamp']
|
timestamp: convo['timestamp']
|
||||||
});
|
});
|
||||||
} else { failed ++}
|
} else {
|
||||||
|
failed++;
|
||||||
}
|
}
|
||||||
console.log(failed, "Conversations could not be imported");
|
}
|
||||||
|
console.log(failed, 'Conversations could not be imported');
|
||||||
return chats;
|
return chats;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const isValidHttpUrl = (string) => {
|
||||||
|
let url;
|
||||||
|
|
||||||
|
try {
|
||||||
|
url = new URL(string);
|
||||||
|
} catch (_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return url.protocol === 'http:' || url.protocol === 'https:';
|
||||||
|
};
|
||||||
|
|
Loading…
Reference in a new issue