import { defaultGenerationConfig } from "$lib/components/inference-playground/generation-config-settings.js";
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore - Svelte imports are broken in TS files
import { showQuotaModal } from "$lib/components/quota-modal.svelte";
import { createInit } from "$lib/spells/create-init.svelte.js";
import {
  PipelineTag,
  type Conversation,
  type ConversationMessage,
  type DefaultProject,
  type Model,
  type Project,
  type Session,
} from "$lib/types.js";
import { safeParse } from "$lib/utils/json.js";
import typia from "typia";
import { models } from "./models.svelte";
import { checkpoints } from "./checkpoints.svelte";
import { handleNonStreamingResponse, handleStreamingResponse } from "$lib/components/inference-playground/utils.js";
import { AbortManager } from "$lib/spells/abort-manager.svelte";
import { addToast } from "$lib/components/toaster.svelte.js";
import { token } from "./token.svelte";

/** localStorage key under which the serialized session is stored. */
const LOCAL_STORAGE_KEY = "hf_inference_playground_session";

/** Per-conversation numbers recorded by the most recent inference run. */
interface GenerationStatistics {
  /** Wall-clock duration of the last run, in milliseconds (rounded). */
  latency: number;
  /** Running total of `completion_tokens` reported by non-streaming responses. */
  generatedTokensCount: number;
}

/** One statistics entry per conversation; the session tracks one or two of them. */
type GenerationStatsList = [GenerationStatistics] | [GenerationStatistics, GenerationStatistics];

/** Template for the empty user message that opens a new conversation. */
const startMessageUser: ConversationMessage = { role: "user", content: "" };

/** Template for the empty system message of a new conversation. */
const systemMessage: ConversationMessage = {
  role: "system",
  content: "",
};

/** Placeholder model used when neither the trending nor the remote model list has an entry. */
export const emptyModel: Model = {
  _id: "",
  inferenceProviderMapping: [],
  pipeline_tag: PipelineTag.TextGeneration,
  trendingScore: 0,
  tags: ["text-generation"],
  id: "",
  config: {
    architectures: [] as string[],
    model_type: "",
    tokenizer_config: {},
  },
};

/**
 * Builds a fresh default conversation and the default project that contains it.
 *
 * The model is the first trending model, else the first remote model, else `emptyModel`.
 * Every template object is copied so that edits to one default conversation can never
 * leak into the module-level templates or into later defaults.
 */
function getDefaults() {
  const defaultModel = models.trending[0] ?? models.remote[0] ?? emptyModel;

  const defaultConversation: Conversation = {
    model: defaultModel,
    config: { ...defaultGenerationConfig },
    messages: [{ ...startMessageUser }],
    systemMessage: { ...systemMessage },
    streaming: true,
  };

  const defaultProject: DefaultProject = {
    name: "Default",
    id: "default",
    conversations: [defaultConversation],
  };

  return { defaultProject, defaultConversation };
}

/**
 * Reactive store for the playground session: the list of projects, the active project,
 * per-conversation generation statistics, and the "is generating" flag.
 *
 * The session is written to localStorage (under `LOCAL_STORAGE_KEY`) by an effect
 * registered in the constructor, once `init` has been called.
 */
class SessionState {
  #value = $state({} as Session);

  /** Statistics for each conversation of the active project, in conversation order. */
  generationStats = $state<GenerationStatsList>([{ latency: 0, generatedTokensCount: 0 }]);

  /** True while `run` is awaiting inference responses. */
  generating = $state(false);

  #abortManager = new AbortManager();

  /**
   * Loads the session. Call once in layout.
   *
   * Starts from a session holding only the default project, replaces it with the
   * localStorage copy when that copy validates as a `Session`, and then applies the
   * `modelId` / `provider` URL query parameters to the default project.
   */
  init = createInit(() => {
    const { defaultConversation, defaultProject } = getDefaults();

    // Get saved session from localStorage if available
    let savedSession: Session = {
      projects: [defaultProject],
      activeProjectId: defaultProject.id,
    };

    const savedData = localStorage.getItem(LOCAL_STORAGE_KEY);
    if (savedData) {
      const parsed = safeParse(savedData);
      const res = typia.validate<Session>(parsed);
      if (res.success) {
        savedSession = res.data;
      } else {
        // Stored data does not match the Session shape: overwrite it with the defaults.
        this.#persist(savedSession);
      }
    }

    // Merge query params with savedSession's default project.
    // Query params models and providers take precedence over savedSession's.
    // The number of conversations kept is the smallest of: existing conversations,
    // models found from the query, and providers in the query — but never less than 1.
    // NOTE(review): an earlier comment here said "maximum between the two"; the code
    // takes the minimum. Confirm which one is intended.
    const dp = savedSession.projects.find(p => p.id === "default");
    if (typia.is<DefaultProject>(dp)) {
      // Parse URL query parameters
      const searchParams = new URLSearchParams(window.location.search);
      const searchProviders = searchParams.getAll("provider");
      const searchModelIds = searchParams.getAll("modelId");
      // Unknown model ids resolve to undefined and are dropped.
      const modelsFromSearch = searchModelIds.map(id => models.remote.find(model => model.id === id)).filter(Boolean);

      if (modelsFromSearch.length > 0) {
        savedSession.activeProjectId = "default";

        let min = Math.min(dp.conversations.length, modelsFromSearch.length, searchProviders.length);
        min = Math.max(1, min);

        const convos = dp.conversations.slice(0, min);
        // Only assign when the truncated list still has a valid conversation count.
        if (typia.is<DefaultProject["conversations"]>(convos)) dp.conversations = convos;

        for (let i = 0; i < min; i++) {
          const conversation = dp.conversations[i] ?? defaultConversation;
          dp.conversations[i] = {
            ...conversation,
            model: modelsFromSearch[i] ?? conversation.model,
            provider: searchProviders[i] ?? conversation.provider,
          };
        }
      }
    }

    this.$ = savedSession;
    this.generationStats = this.project.conversations.map(_ => ({
      latency: 0,
      generatedTokensCount: 0,
    })) as GenerationStatsList;

    this.#abortManager.init();
  });

  constructor() {
    $effect.root(() => {
      $effect(() => {
        // Do not persist anything before the session has been loaded.
        if (!this.init.called) return;
        const v = $state.snapshot(this.#value);
        this.#persist(v);
      });
    });
  }

  /** Writes `value` to localStorage; a failing write is logged instead of thrown. */
  #persist(value: Session) {
    try {
      localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(value));
    } catch (e) {
      console.error("Failed to save session to localStorage:", e);
    }
  }

  /** The whole reactive session object. */
  get $() {
    return this.#value;
  }

  set $(v: Session) {
    this.#value = v;
  }

  /** Replaces the session with `s` only when `s` validates as a `Session`; otherwise does nothing. */
  #setAnySession(s: unknown) {
    if (typia.is<Session>(s)) this.$ = s;
  }

  /**
   * Saves the current default project as a new named project and makes it active.
   * The default project is then reset to a single fresh default conversation.
   *
   * @param args.name Name of the new project.
   * @param args.moveCheckpoints When true, `checkpoints.migrate` is called with the default
   *   project's id and the new project's id.
   */
  saveProject = (args: { name: string; moveCheckpoints?: boolean }) => {
    const defaultProject = this.$.projects.find(p => p.id === "default");
    if (!defaultProject) return;

    const project: Project = {
      ...defaultProject,
      name: args.name,
      id: crypto.randomUUID(),
    };

    if (args.moveCheckpoints) {
      checkpoints.migrate(defaultProject.id, project.id);
    }

    defaultProject.conversations = [getDefaults().defaultConversation];

    this.addProject(project);
  };

  /** Appends `project` to the session and makes it the active project. */
  addProject = (project: Project) => {
    this.$ = { ...this.$, projects: [...this.$.projects, project], activeProjectId: project.id };
  };

  /**
   * Removes the project with the given id and clears its checkpoints.
   * The default project cannot be deleted. If the deleted project was active, the first
   * remaining project becomes active; if nothing remains, a fresh default project is installed.
   */
  deleteProject = (id: string) => {
    // Can't delete default project!
    if (id === "default") return;

    const projects = this.$.projects.filter(p => p.id !== id);
    if (projects.length === 0) {
      const { defaultProject } = getDefaults();
      this.#setAnySession({ ...this.$, projects: [defaultProject], activeProjectId: defaultProject.id });
    } else {
      const currProject = projects.find(p => p.id === this.$.activeProjectId);
      this.#setAnySession({ ...this.$, projects, activeProjectId: currProject?.id ?? projects[0]?.id });
    }

    checkpoints.clear(id);
  };

  /** Shallow-merges `data` into the project with the given id; the result is applied only if it is still a valid session. */
  updateProject = (id: string, data: Partial<Project>) => {
    const projects = this.$.projects.map(p => (p.id === id ? { ...p, ...data } : p));
    this.#setAnySession({ ...this.$, projects });
  };

  /** The active project, falling back to the first project when the active id matches none. */
  get project() {
    return this.$.projects.find(p => p.id === this.$.activeProjectId) ?? this.$.projects[0];
  }

  /** Replaces the stored project that has the same id as `np`. */
  set project(np: Project) {
    const projects = this.$.projects.map(p => (p.id === np.id ? np : p));
    this.#setAnySession({ ...this.$, projects });
  }

  /**
   * Runs one inference request for `conversation` and appends the assistant reply to it.
   * Afterwards the latency (and, for non-streaming responses, the token count) is recorded
   * in the `generationStats` entry at the conversation's index in the active project.
   */
  async #runInference(conversation: Conversation) {
    const idx = this.project.conversations.indexOf(conversation);
    const startTime = performance.now();

    if (conversation.streaming) {
      let addedMessage = false;
      const streamingMessage = $state<ConversationMessage>({ role: "assistant", content: "" });

      await handleStreamingResponse(
        conversation,
        content => {
          streamingMessage.content = content;
          // The message is attached on the first chunk, so a request that fails before
          // producing any content leaves the conversation untouched.
          if (!addedMessage) {
            conversation.messages = [...conversation.messages, streamingMessage];
            addedMessage = true;
          }
        },
        this.#abortManager.createController()
      );
    } else {
      const { message: newMessage, completion_tokens: newTokensCount } = await handleNonStreamingResponse(conversation);
      conversation.messages = [...conversation.messages, newMessage];
      const tokenStats = this.generationStats[idx];
      if (tokenStats) tokenStats.generatedTokensCount += newTokensCount;
    }

    const endTime = performance.now();
    const stats = this.generationStats[idx];
    if (stats) stats.latency = Math.round(endTime - startTime);
  }

  /**
   * Runs inference for the selected conversation(s) of the active project.
   *
   * - Without a token, opens the token modal and returns.
   * - If any selected conversation already ends with an assistant message, shows an error
   *   toast and returns without running anything.
   * - On failure, removes a trailing empty assistant message from each selected conversation,
   *   then reports the error (quota modal, token reset, and/or toast; abort errors get no toast).
   *
   * @param conv `"left"` / `"right"` select the first / second conversation, `"both"` selects
   *   all of them; a `Conversation` object selects exactly that conversation.
   */
  async run(conv: "left" | "right" | "both" | Conversation = "both") {
    if (!token.value) {
      token.showModal = true;
      return;
    }

    const projectConversations = this.project.conversations;
    const conversations: Conversation[] =
      typeof conv === "string"
        ? projectConversations.filter((_, idx) => conv === "both" || (conv === "left" ? idx === 0 : idx === 1))
        : [conv];

    for (const conversation of conversations) {
      if (conversation.messages.at(-1)?.role !== "assistant") continue;

      let prefix = "";
      if (projectConversations.length === 2) {
        // Name the side by the conversation's position in the project, not in the
        // filtered selection (where a lone "right" conversation would sit at index 0).
        const position = projectConversations.indexOf(conversation);
        if (position !== -1) {
          prefix = `Error on ${position === 0 ? "left" : "right"} conversation. `;
        }
      }
      return addToast({
        title: "Failed to run inference",
        description: `${prefix}Messages must alternate between user/assistant roles.`,
        variant: "error",
      });
    }

    // activeElement is null when nothing is focused.
    (document.activeElement as HTMLElement | null)?.blur();
    this.generating = true;

    try {
      await Promise.all(conversations.map(c => this.#runInference(c)));
    } catch (error) {
      for (const conversation of conversations) {
        const lastMessage = conversation.messages.at(-1);
        if (lastMessage?.role === "assistant" && !lastMessage.content?.trim()) {
          conversation.messages.pop();
          conversation.messages = [...conversation.messages];
        }
      }
      // eslint-disable-next-line no-self-assign
      this.$ = this.$;

      if (error instanceof Error) {
        const msg = error.message;
        const lowered = msg.toLowerCase();
        // "montly" is kept for messages that carry that misspelling; "monthly" covers the
        // correctly spelled form, which the misspelled check alone could never match.
        if (lowered.includes("monthly") || lowered.includes("montly") || lowered.includes("pro")) {
          showQuotaModal();
        }

        if (msg.includes("token seems invalid")) {
          token.reset();
        }

        if (error.name !== "AbortError") {
          addToast({ title: "Error", description: msg, variant: "error" });
        }
      } else {
        addToast({ title: "Error", description: "An unknown error occurred", variant: "error" });
      }
    } finally {
      this.generating = false;
      this.#abortManager.clear();
    }
  }

  /** Aborts every in-flight request and clears the generating flag. */
  stopGenerating = () => {
    this.#abortManager.abortAll();
    this.generating = false;
  };

  /** Stops generation if one is in progress; otherwise starts `run` with the given selection. */
  runOrStop = (c?: Parameters<SessionState["run"]>[0]) => {
    if (this.generating) {
      this.stopGenerating();
    } else {
      this.run(c);
    }
  };
}

/** The application-wide session store. */
export const session = new SessionState();