Merge pull request #333 from ollama-webui/rag

feat: RAG support
This commit is contained in:
Timothy Jaeryang Baek 2024-01-07 02:50:32 -08:00 committed by GitHub
commit 34e0f64fb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 659 additions and 112 deletions

105
src/lib/apis/rag/index.ts Normal file
View file

@ -0,0 +1,105 @@
import { RAG_API_BASE_URL } from '$lib/constants';
export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
const data = new FormData();
data.append('file', file);
data.append('collection_name', collection_name);
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/doc`, {
method: 'POST',
headers: {
Accept: 'application/json',
authorization: `Bearer ${token}`
},
body: data
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const uploadWebToVectorDB = async (token: string, collection_name: string, url: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/web`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
url: url,
collection_name: collection_name
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const queryVectorDB = async (
token: string,
collection_name: string,
query: string,
k: number
) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.set('query', query);
if (k) {
searchParams.set('k', k.toString());
}
const res = await fetch(
`${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`,
{
method: 'GET',
headers: {
Accept: 'application/json',
authorization: `Bearer ${token}`
}
}
)
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -2,10 +2,11 @@
import toast from 'svelte-french-toast';
import { onMount, tick } from 'svelte';
import { settings } from '$lib/stores';
import { findWordIndices } from '$lib/utils';
import { calculateSHA256, findWordIndices } from '$lib/utils';
import Prompts from './MessageInput/PromptCommands.svelte';
import Suggestions from './MessageInput/Suggestions.svelte';
import { uploadDocToVectorDB } from '$lib/apis/rag';
export let submitPrompt: Function;
export let stopResponse: Function;
@ -98,7 +99,7 @@
dragged = true;
});
dropZone.addEventListener('drop', (e) => {
dropZone.addEventListener('drop', async (e) => {
e.preventDefault();
console.log(e);
@ -115,14 +116,32 @@
];
};
if (
e.dataTransfer?.files &&
e.dataTransfer?.files.length > 0 &&
['image/gif', 'image/jpeg', 'image/png'].includes(e.dataTransfer?.files[0]['type'])
) {
reader.readAsDataURL(e.dataTransfer?.files[0]);
const inputFiles = e.dataTransfer?.files;
if (inputFiles && inputFiles.length > 0) {
const file = inputFiles[0];
if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
reader.readAsDataURL(file);
} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
console.log(file);
const hash = (await calculateSHA256(file)).substring(0, 63);
const res = await uploadDocToVectorDB(localStorage.token, hash, file);
if (res) {
files = [
...files,
{
type: 'doc',
name: file.name,
collection_name: res.collection_name
}
];
}
} else {
toast.error(`Unsupported File Type '${file['type']}'.`);
}
} else {
toast.error(`Unsupported File Type '${e.dataTransfer?.files[0]['type']}'.`);
toast.error(`File not found.`);
}
}
@ -145,11 +164,11 @@
<div class="absolute rounded-xl w-full h-full backdrop-blur bg-gray-800/40 flex justify-center">
<div class="m-auto pt-64 flex flex-col justify-center">
<div class="max-w-md">
<div class=" text-center text-6xl mb-3">🏞</div>
<div class="text-center dark:text-white text-2xl font-semibold z-50">Add Images</div>
<div class=" text-center text-6xl mb-3">🗂</div>
<div class="text-center dark:text-white text-2xl font-semibold z-50">Add Files</div>
<div class=" mt-2 text-center text-sm dark:text-gray-200 w-full">
Drop any images here to add to the conversation
Drop any files/images here to add to the conversation
</div>
</div>
</div>
@ -204,7 +223,7 @@
bind:files={inputFiles}
type="file"
hidden
on:change={() => {
on:change={async () => {
let reader = new FileReader();
reader.onload = (event) => {
files = [
@ -218,15 +237,32 @@
filesInputElement.value = '';
};
if (
inputFiles &&
inputFiles.length > 0 &&
['image/gif', 'image/jpeg', 'image/png'].includes(inputFiles[0]['type'])
) {
reader.readAsDataURL(inputFiles[0]);
if (inputFiles && inputFiles.length > 0) {
const file = inputFiles[0];
if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
reader.readAsDataURL(file);
} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
console.log(file);
const hash = (await calculateSHA256(file)).substring(0, 63);
const res = await uploadDocToVectorDB(localStorage.token, hash, file);
if (res) {
files = [
...files,
{
type: 'doc',
name: file.name,
collection_name: res.collection_name
}
];
filesInputElement.value = '';
}
} else {
toast.error(`Unsupported File Type '${file['type']}'.`);
inputFiles = null;
}
} else {
toast.error(`Unsupported File Type '${inputFiles[0]['type']}'.`);
inputFiles = null;
toast.error(`File not found.`);
}
}}
/>
@ -237,10 +273,42 @@
}}
>
{#if files.length > 0}
<div class="ml-2 mt-2 mb-1 flex space-x-2">
<div class="mx-2 mt-2 mb-1 flex flex-wrap gap-2">
{#each files as file, fileIdx}
<div class=" relative group">
<img src={file.url} alt="input" class=" h-16 w-16 rounded-xl object-cover" />
{#if file.type === 'image'}
<img src={file.url} alt="input" class=" h-16 w-16 rounded-xl object-cover" />
{:else if file.type === 'doc'}
<div
class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none"
>
<div class="p-2.5 bg-red-400 text-white rounded-lg">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="w-6 h-6"
>
<path
fill-rule="evenodd"
d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
clip-rule="evenodd"
/>
<path
d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
/>
</svg>
</div>
<div class="flex flex-col justify-center -space-y-0.5">
<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1">
{file.name}
</div>
<div class=" text-gray-500 text-sm">Document</div>
</div>
</div>
{/if}
<div class=" absolute -top-1 -right-1">
<button

View file

@ -53,11 +53,41 @@
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:my-0 prose-p:-mb-4 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-6 prose-li:-mb-4 whitespace-pre-line"
>
{#if message.files}
<div class="my-3 w-full flex overflow-x-auto space-x-2">
<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
{#each message.files as file}
<div>
{#if file.type === 'image'}
<img src={file.url} alt="input" class=" max-h-96 rounded-lg" draggable="false" />
{:else if file.type === 'doc'}
<div
class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none"
>
<div class="p-2.5 bg-red-400 text-white rounded-lg">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="w-6 h-6"
>
<path
fill-rule="evenodd"
d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
clip-rule="evenodd"
/>
<path
d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
/>
</svg>
</div>
<div class="flex flex-col justify-center -space-y-0.5">
<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1">
{file.name}
</div>
<div class=" text-gray-500 text-sm">Document</div>
</div>
</div>
{/if}
</div>
{/each}

View file

@ -5,6 +5,7 @@ export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
export const WEB_UI_VERSION = 'v1.0.0-alpha-static';

View file

@ -127,3 +127,37 @@ export const findWordIndices = (text) => {
return matches;
};
export const calculateSHA256 = async (file) => {
// Create a FileReader to read the file asynchronously
const reader = new FileReader();
// Define a promise to handle the file reading
const readFile = new Promise((resolve, reject) => {
reader.onload = () => resolve(reader.result);
reader.onerror = reject;
});
// Read the file as an ArrayBuffer
reader.readAsArrayBuffer(file);
try {
// Wait for the FileReader to finish reading the file
const buffer = await readFile;
// Convert the ArrayBuffer to a Uint8Array
const uint8Array = new Uint8Array(buffer);
// Calculate the SHA-256 hash using Web Crypto API
const hashBuffer = await crypto.subtle.digest('SHA-256', uint8Array);
// Convert the hash to a hexadecimal string
const hashArray = Array.from(new Uint8Array(hashBuffer));
const hashHex = hashArray.map((byte) => byte.toString(16).padStart(2, '0')).join('');
return `${hashHex}`;
} catch (error) {
console.error('Error calculating SHA-256 hash:', error);
throw error;
}
};

View file

@ -0,0 +1,20 @@
export const RAGTemplate = (context: string, query: string) => {
let template = `Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]`;
template = template.replace(/\[context\]/g, context);
template = template.replace(/\[query\]/g, query);
return template;
};

View file

@ -7,16 +7,18 @@
import { page } from '$app/stores';
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
import { copyToClipboard, splitStream } from '$lib/utils';
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
import { copyToClipboard, splitStream } from '$lib/utils';
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
import { queryVectorDB } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
import Messages from '$lib/components/chat/Messages.svelte';
import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
import Navbar from '$lib/components/layout/Navbar.svelte';
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { RAGTemplate } from '$lib/utils/rag';
let stopResponseFlag = false;
let autoScroll = true;
@ -113,8 +115,108 @@
// Ollama functions
//////////////////////////
const submitPrompt = async (userPrompt) => {
console.log('submitPrompt', $chatId);
if (selectedModels.includes('')) {
toast.error('Model not selected');
} else if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done
console.log('wait');
} else {
// Reset chat message textarea height
document.getElementById('chat-textarea').style.height = '';
// Create user message
let userMessageId = uuidv4();
let userMessage = {
id: userMessageId,
parentId: messages.length !== 0 ? messages.at(-1).id : null,
childrenIds: [],
role: 'user',
content: userPrompt,
files: files.length > 0 ? files : undefined
};
// Add message to history and Set currentId to messageId
history.messages[userMessageId] = userMessage;
history.currentId = userMessageId;
// Append messageId to childrenIds of parent message
if (messages.length !== 0) {
history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
}
// Wait until history/message have been updated
await tick();
// Create new chat if only one message in messages
if (messages.length == 1) {
if ($settings.saveChatHistory ?? true) {
chat = await createNewChat(localStorage.token, {
id: $chatId,
title: 'New Chat',
models: selectedModels,
system: $settings.system ?? undefined,
options: {
...($settings.options ?? {})
},
messages: messages,
history: history,
timestamp: Date.now()
});
await chats.set(await getChatList(localStorage.token));
await chatId.set(chat.id);
} else {
await chatId.set('local');
}
await tick();
}
// Reset chat input textarea
prompt = '';
files = [];
// Send prompt
await sendPrompt(userPrompt, userMessageId);
}
};
const sendPrompt = async (prompt, parentId) => {
const _chatId = JSON.parse(JSON.stringify($chatId));
const docs = messages
.filter((message) => message?.files ?? null)
.map((message) => message.files.filter((item) => item.type === 'doc'))
.flat(1);
console.log(docs);
if (docs.length > 0) {
const query = history.messages[parentId].content;
let relevantContexts = await Promise.all(
docs.map(async (doc) => {
return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(
(error) => {
console.log(error);
return null;
}
);
})
);
relevantContexts = relevantContexts.filter((context) => context);
const contextString = relevantContexts.reduce((a, context, i, arr) => {
return `${a}${context.documents.join(' ')}\n`;
}, '');
console.log(contextString);
history.messages[parentId].raContent = RAGTemplate(contextString, query);
history.messages[parentId].contexts = relevantContexts;
await tick();
}
await Promise.all(
selectedModels.map(async (model) => {
console.log(model);
@ -177,7 +279,7 @@
.filter((message) => message)
.map((message) => ({
role: message.role,
content: message.content,
content: message?.raContent ?? message.content,
...(message.files && {
images: message.files
.filter((file) => file.type === 'image')
@ -366,7 +468,7 @@
content: [
{
type: 'text',
text: message.content
text: message?.raContent ?? message.content
},
...message.files
.filter((file) => file.type === 'image')
@ -378,7 +480,7 @@
}))
]
}
: { content: message.content })
: { content: message?.raContent ?? message.content })
})),
seed: $settings?.options?.seed ?? undefined,
stop: $settings?.options?.stop ?? undefined,
@ -494,73 +596,6 @@
}
};
const submitPrompt = async (userPrompt) => {
console.log('submitPrompt', $chatId);
if (selectedModels.includes('')) {
toast.error('Model not selected');
} else if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done
console.log('wait');
} else {
// Reset chat message textarea height
document.getElementById('chat-textarea').style.height = '';
// Create user message
let userMessageId = uuidv4();
let userMessage = {
id: userMessageId,
parentId: messages.length !== 0 ? messages.at(-1).id : null,
childrenIds: [],
role: 'user',
content: userPrompt,
files: files.length > 0 ? files : undefined
};
// Add message to history and Set currentId to messageId
history.messages[userMessageId] = userMessage;
history.currentId = userMessageId;
// Append messageId to childrenIds of parent message
if (messages.length !== 0) {
history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
}
// Wait until history/message have been updated
await tick();
// Create new chat if only one message in messages
if (messages.length == 1) {
if ($settings.saveChatHistory ?? true) {
chat = await createNewChat(localStorage.token, {
id: $chatId,
title: 'New Chat',
models: selectedModels,
system: $settings.system ?? undefined,
options: {
...($settings.options ?? {})
},
messages: messages,
history: history,
timestamp: Date.now()
});
await chats.set(await getChatList(localStorage.token));
await chatId.set(chat.id);
} else {
await chatId.set('local');
}
await tick();
}
// Reset chat input textarea
prompt = '';
files = [];
// Send prompt
await sendPrompt(userPrompt, userMessageId);
}
};
const stopResponse = () => {
stopResponseFlag = true;
console.log('stopResponse');