Merge pull request #407 from anuraagdjain/feat/parallel-model-downloads

feat: parallel model downloads
This commit is contained in:
Timothy Jaeryang Baek 2024-01-11 12:53:21 -08:00 committed by GitHub
commit ed4b3e0b32
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 152 additions and 552 deletions

View file

@ -299,13 +299,30 @@ export const pullModel = async (token: string, tagName: string) => {
name: tagName
})
}).catch((err) => {
console.log(err);
error = err;
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
return res;
};
// export const pullModel = async (token: string, tagName: string) => {
// return await fetch(`${OLLAMA_API_BASE_URL}/pull`, {
// method: 'POST',
// headers: {
// 'Content-Type': 'text/event-stream',
// Authorization: `Bearer ${token}`
// },
// body: JSON.stringify({
// name: tagName
// })
// });
// };

View file

@ -1,11 +1,11 @@
<script lang="ts">
import toast from 'svelte-french-toast';
import queue from 'async/queue';
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
import { goto } from '$app/navigation';
import { onMount } from 'svelte';
import { config, models, settings, user, chats } from '$lib/stores';
import { splitStream, getGravatarURL } from '$lib/utils';
import {
getOllamaVersion,
@ -16,14 +16,16 @@
createModel,
deleteModel
} from '$lib/apis/ollama';
import { updateUserPassword } from '$lib/apis/auths';
import { createNewChat, deleteAllChats, getAllChats, getChatList } from '$lib/apis/chats';
import { WEB_UI_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
import { config, models, settings, user, chats } from '$lib/stores';
import { splitStream, getGravatarURL } from '$lib/utils';
import Advanced from './Settings/Advanced.svelte';
import Modal from '../common/Modal.svelte';
import { updateUserPassword } from '$lib/apis/auths';
import { goto } from '$app/navigation';
import Page from '../../../routes/(app)/+page.svelte';
import {
getOpenAIKey,
getOpenAIModels,
@ -71,8 +73,15 @@
};
// Models
let modelTransferring = false;
const MAX_PARALLEL_DOWNLOADS = 3;
const modelDownloadQueue = queue(
(task: { modelName: string }, cb) =>
pullModelHandlerProcessor({ modelName: task.modelName, callback: cb }),
MAX_PARALLEL_DOWNLOADS
);
let modelDownloadStatus: Record<string, any> = {};
let modelTransferring = false;
let modelTag = '';
let digest = '';
let pullProgress = null;
@ -87,7 +96,6 @@
let deleteModelTag = '';
// External
let OPENAI_API_KEY = '';
let OPENAI_API_BASE_URL = '';
@ -104,6 +112,32 @@
let importFiles;
let showDeleteConfirm = false;
// Auth
let authEnabled = false;
let authType = 'Basic';
let authContent = '';
// Account
let currentPassword = '';
let newPassword = '';
let newPasswordConfirm = '';
// About
let ollamaVersion = '';
$: if (importFiles) {
console.log(importFiles);
let reader = new FileReader();
reader.onload = (event) => {
let chats = JSON.parse(event.target.result);
console.log(chats);
importChats(chats);
};
reader.readAsText(importFiles[0]);
}
const importChats = async (_chats) => {
for (const chat of _chats) {
console.log(chat);
@ -120,38 +154,12 @@
saveAs(blob, `chat-export-${Date.now()}.json`);
};
$: if (importFiles) {
console.log(importFiles);
let reader = new FileReader();
reader.onload = (event) => {
let chats = JSON.parse(event.target.result);
console.log(chats);
importChats(chats);
};
reader.readAsText(importFiles[0]);
}
// Delete every stored chat for the signed-in user, then refresh the chat list store.
const deleteChats = async () => {
// Navigate home first so the UI is not left on a chat that is about to be removed.
await goto('/');
await deleteAllChats(localStorage.token);
// Re-fetch the chat list so the store (and sidebar) reflect the deletion.
await chats.set(await getChatList(localStorage.token));
};
// Auth
let authEnabled = false;
let authType = 'Basic';
let authContent = '';
// Account
let currentPassword = '';
let newPassword = '';
let newPasswordConfirm = '';
// About
let ollamaVersion = '';
const updateOllamaAPIUrlHandler = async () => {
API_BASE_URL = await updateOllamaAPIUrl(localStorage.token, API_BASE_URL);
const _models = await getModels('ollama');
@ -247,10 +255,11 @@
saveSettings({ saveChatHistory: saveChatHistory });
};
const pullModelHandler = async () => {
modelTransferring = true;
const res = await pullModel(localStorage.token, modelTag);
const pullModelHandlerProcessor = async (opts: { modelName: string; callback: Function }) => {
const res = await pullModel(localStorage.token, opts.modelName).catch((error) => {
opts.callback({ success: false, error, modelName: opts.modelName });
return null;
});
if (res) {
const reader = res.body
@ -259,92 +268,89 @@
.getReader();
while (true) {
const { value, done } = await reader.read();
if (done) break;
try {
const { value, done } = await reader.read();
if (done) break;
let lines = value.split('\n');
for (const line of lines) {
if (line !== '') {
console.log(line);
let data = JSON.parse(line);
console.log(data);
if (data.error) {
throw data.error;
}
if (data.detail) {
throw data.detail;
}
if (data.status) {
if (!data.digest) {
toast.success(data.status);
if (data.status === 'success') {
const notification = new Notification(`Ollama`, {
body: `Model '${modelTag}' has been successfully downloaded.`,
icon: '/favicon.png'
});
}
} else {
digest = data.digest;
if (data.digest) {
let downloadProgress = 0;
if (data.completed) {
pullProgress = Math.round((data.completed / data.total) * 1000) / 10;
downloadProgress = Math.round((data.completed / data.total) * 1000) / 10;
} else {
pullProgress = 100;
downloadProgress = 100;
}
modelDownloadStatus[opts.modelName] = {
pullProgress: downloadProgress,
digest: data.digest
};
} else {
toast.success(data.status);
}
}
}
}
} catch (error) {
console.log(error);
toast.error(error);
if (typeof error !== 'string') {
error = error.message;
}
opts.callback({ success: false, error, modelName: opts.modelName });
}
}
opts.callback({ success: true, modelName: opts.modelName });
}
};
// Queue a pull of the model named in `modelTag`, de-duplicating repeat requests
// and capping the number of simultaneous downloads.
const pullModelHandler = async () => {
// A status entry exists for every queued/in-flight pull, so this rejects duplicates.
if (modelDownloadStatus[modelTag]) {
toast.error(`Model '${modelTag}' is already in queue for downloading.`);
return;
}
// NOTE(review): hard-coded 3 should presumably stay in sync with
// MAX_PARALLEL_DOWNLOADS — consider reusing that constant here.
if (Object.keys(modelDownloadStatus).length === 3) {
toast.error('Maximum of 3 models can be downloaded simultaneously. Please try again later.');
return;
}
modelTransferring = true;
// Hand the pull off to the async queue; this callback fires when the queue
// worker reports completion or failure for this model.
modelDownloadQueue.push(
{ modelName: modelTag },
async (data: { modelName: string; success: boolean; error?: Error }) => {
const { modelName } = data;
// Remove the downloaded model
delete modelDownloadStatus[modelName];
console.log(data);
if (!data.success) {
toast.error(data.error);
} else {
toast.success(`Model '${modelName}' has been successfully downloaded.`);
// Desktop notification — NOTE(review): assumes Notification permission
// was already granted; confirm it is requested elsewhere.
const notification = new Notification(`Ollama`, {
body: `Model '${modelName}' has been successfully downloaded.`,
icon: '/favicon.png'
});
// Refresh the model list so the new model shows up immediately.
models.set(await getModels());
}
}
);
// Clear the input right away; the download itself continues in the queue.
modelTag = '';
modelTransferring = false;
// NOTE(review): this trailing refresh looks like leftover from the pre-queue
// implementation — the queue callback above already refreshes `models`.
models.set(await getModels());
};
// Compute the SHA-256 digest of a File's contents, returned in the
// "sha256:<hex>" form used for model digests.
const calculateSHA256 = async (file) => {
	console.log(file);
	const reader = new FileReader();
	// Bridge the callback-based FileReader API into a promise we can await.
	const loaded = new Promise((resolve, reject) => {
		reader.onload = () => resolve(reader.result);
		reader.onerror = reject;
	});
	// Start the asynchronous read of the file's raw bytes.
	reader.readAsArrayBuffer(file);
	try {
		// Wait for the bytes, then hash them with the Web Crypto API.
		const buffer = await loaded;
		const digestBuffer = await crypto.subtle.digest('SHA-256', new Uint8Array(buffer));
		// Render the digest as lowercase hex.
		const hexDigest = Array.from(new Uint8Array(digestBuffer))
			.map((byte) => byte.toString(16).padStart(2, '0'))
			.join('');
		return `sha256:${hexDigest}`;
	} catch (error) {
		// Log read/hash failures with context, then let the caller handle them.
		console.error('Error calculating SHA-256 hash:', error);
		throw error;
	}
};
const uploadModelHandler = async () => {
@ -1158,7 +1164,7 @@
</button>
</div>
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
To access the available model names for downloading, <a
class=" text-gray-500 dark:text-gray-300 font-medium"
href="https://ollama.ai/library"
@ -1166,23 +1172,29 @@
>
</div>
{#if pullProgress !== null}
<div class="mt-2">
<div class=" mb-2 text-xs">Pull Progress</div>
<div class="w-full rounded-full dark:bg-gray-800">
<div
class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
style="width: {Math.max(15, pullProgress ?? 0)}%"
>
{pullProgress ?? 0}%
{#if Object.keys(modelDownloadStatus).length > 0}
{#each Object.keys(modelDownloadStatus) as model}
<div class="flex flex-col">
<div class="font-medium mb-1">{model}</div>
<div class="">
<div
class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
style="width: {Math.max(
15,
modelDownloadStatus[model].pullProgress ?? 0
)}%"
>
{modelDownloadStatus[model].pullProgress ?? 0}%
</div>
<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
{modelDownloadStatus[model].digest}
</div>
</div>
</div>
<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
{digest}
</div>
</div>
{/each}
{/if}
</div>
<hr class=" dark:border-gray-700" />
<div>