feat: parallel model downloads

Anuraag Jain 2024-01-06 12:10:41 +02:00
parent cb93038abf
commit ea721feea9
3 changed files with 79 additions and 13 deletions
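
The heart of this change is a concurrency-limited download queue built on the async package: the component diff below creates it with queue(worker, 3), importing queue from 'async/queue'. Below is a minimal standalone sketch of that pattern, assuming async v3's callback-style worker API; fetchModel and the model names are illustrative stand-ins, not part of the commit.

import { queue } from 'async'; // the commit imports the 'async/queue' subpath; the root import exposes the same function

type DownloadTask = { modelName: string };

// Stand-in for the component's real streaming /pull request.
const fetchModel = async (modelName: string): Promise<void> => {
	console.log(`downloading ${modelName} ...`);
};

// Worker plus a concurrency limit of 3: at most three downloads run at once,
// any further pushes wait in the queue until a slot frees up.
const downloadQueue = queue<DownloadTask>((task, done) => {
	fetchModel(task.modelName)
		.then(() => done())
		.catch((err) => done(err));
}, 3);

// Every push carries its own completion callback, so the UI can react per model.
for (const name of ['model-a', 'model-b', 'model-c', 'model-d']) {
	downloadQueue.push({ modelName: name }, (err) => {
		if (err) console.error(`pull of ${name} failed`, err);
		else console.log(`pull of ${name} finished`);
	});
}

Each worker invocation handles one model; calling done (with an error on failure) releases the slot so the next queued download can start.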

package-lock.json (generated)

@@ -9,6 +9,7 @@
"version": "0.0.1",
"dependencies": {
"@sveltejs/adapter-node": "^1.3.1",
"async": "^3.2.5",
"file-saver": "^2.0.5",
"highlight.js": "^11.9.0",
"idb": "^7.1.1",
@@ -1208,6 +1209,11 @@
"node": ">=8"
}
},
"node_modules/async": {
"version": "3.2.5",
"resolved": "https://registry.npmjs.org/async/-/async-3.2.5.tgz",
"integrity": "sha512-baNZyqaaLhyLVKm/DlvdW051MSgO6b8eVfIezl9E5PqWxFgzLm/wQntEW4zOytVburDEr0JlALEpdOFwvErLsg=="
},
"node_modules/autoprefixer": {
"version": "10.4.16",
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz",
@@ -4645,6 +4651,11 @@
"integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==",
"dev": true
},
"async": {
"version": "3.2.5",
"resolved": "https://registry.npmjs.org/async/-/async-3.2.5.tgz",
"integrity": "sha512-baNZyqaaLhyLVKm/DlvdW051MSgO6b8eVfIezl9E5PqWxFgzLm/wQntEW4zOytVburDEr0JlALEpdOFwvErLsg=="
},
"autoprefixer": {
"version": "10.4.16",
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz",

package.json

@@ -39,6 +39,7 @@
"type": "module",
"dependencies": {
"@sveltejs/adapter-node": "^1.3.1",
"async": "^3.2.5",
"file-saver": "^2.0.5",
"highlight.js": "^11.9.0",
"idb": "^7.1.1",


@@ -6,6 +6,7 @@
import { onMount } from 'svelte';
import { config, models, settings, user, chats } from '$lib/stores';
import { splitStream, getGravatarURL } from '$lib/utils';
import queue from 'async/queue';
import { getOllamaVersion } from '$lib/apis/ollama';
import { createNewChat, deleteAllChats, getAllChats, getChatList } from '$lib/apis/chats';
@@ -38,6 +39,8 @@
let theme = 'dark';
let notificationEnabled = false;
let system = '';
const modelDownloadQueue = queue((task:{modelName: string}, cb) => pullModelHandlerProcessor({modelName: task.modelName, callback: cb}), 3);
let modelDownloadStatus: Record<string, any> = {};
// Advanced
let requestFormat = '';
@@ -224,8 +227,9 @@
authEnabled = !authEnabled;
};
const pullModelHandler = async () => {
modelTransferring = true;
const pullModelHandlerProcessor = async (opts:{modelName:string, callback: Function}) => {
console.log('Pull model name', opts.modelName);
const res = await fetch(`${API_BASE_URL}/pull`, {
method: 'POST',
headers: {
@@ -234,7 +238,7 @@
...($user && { Authorization: `Bearer ${localStorage.token}` })
},
body: JSON.stringify({
name: modelTag
name: opts.modelName
})
});
@@ -265,11 +269,9 @@
}
if (data.status) {
if (!data.digest) {
toast.success(data.status);
if (data.status === 'success') {
const notification = new Notification(`Ollama`, {
body: `Model '${modelTag}' has been successfully downloaded.`,
body: `Model '${opts.modelName}' has been successfully downloaded.`,
icon: '/favicon.png'
});
}
@@ -280,21 +282,48 @@
} else {
pullProgress = 100;
}
modelDownloadStatus[opts.modelName] = {pullProgress};
}
}
}
}
} catch (error) {
console.log(error);
toast.error(error);
console.error(error);
opts.callback({success:false, error, modelName: opts.modelName});
}
}
opts.callback({success: true, modelName: opts.modelName});
};
const pullModelHandler = async() => {
if(modelDownloadStatus[modelTag]){
toast.error("Model already in queue for downloading.");
return;
}
if(Object.keys(modelDownloadStatus).length === 3){
toast.error('Maximum of 3 models can be downloaded simultaneously. Please try again later.');
return;
}
modelTransferring = true;
modelDownloadQueue.push({modelName: modelTag},async (data:{modelName: string; success: boolean; error?: Error}) => {
const {modelName} = data;
// Remove the model from the download-status list
delete modelDownloadStatus[modelName];
if(!data.success){
toast.error(`There was an issue downloading the model ${modelName}`);
return;
}
toast.success(`Model ${modelName} was successfully downloaded`);
models.set(await getModels());
});
modelTag = '';
modelTransferring = false;
modelTransferring = false;
}
models.set(await getModels());
};
const calculateSHA256 = async (file) => {
console.log(file);
@@ -1248,7 +1277,7 @@
>
</div>
{#if pullProgress !== null}
<!-- {#if pullProgress !== null}
<div class="mt-2">
<div class=" mb-2 text-xs">Pull Progress</div>
<div class="w-full rounded-full dark:bg-gray-800">
@@ -1263,8 +1292,33 @@
{digest}
</div>
</div>
{/if}
{/if} -->
</div>
{#if Object.keys(modelDownloadStatus).length > 0}
<table class="w-full text-sm text-left text-gray-500 dark:text-gray-400">
<thead
class="text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-700 dark:text-gray-400"
>
<tr>
<th scope="col" class="px-6 py-3"> Model Name </th>
<th scope="col" class="px-6 py-3"> Download progress </th>
</tr>
</thead>
<tbody>
{#each Object.entries(modelDownloadStatus) as [modelName, payload]}
<tr class="bg-white border-b dark:bg-gray-800 dark:border-gray-700">
<td class="px-6 py-4">{modelName}</td>
<td class="px-6 py-4"><div
class="dark:bg-gray-600 text-xs font-medium text-blue-100 text-center p-0.5 leading-none rounded-full"
style="width: {Math.max(15, payload.pullProgress ?? 0)}%"
>
{ payload.pullProgress ?? 0}%
</div></td>
</tr>
{/each}
</tbody>
</table>
{/if}
<hr class=" dark:border-gray-700" />
<div>