import type { ToolResult, Tool } from "$lib/types/Tool";
import {
	MessageReasoningUpdateType,
	MessageUpdateType,
	type MessageUpdate,
} from "$lib/types/MessageUpdate";
import { AbortedGenerations } from "../abortedGenerations";
import type { TextGenerationContext } from "./types";
import type { EndpointMessage } from "../endpoints/endpoints";
import { generateFromDefaultEndpoint } from "../generateFromDefaultEndpoint";
import { generateSummaryOfReasoning } from "./reasoning";
import { logger } from "../logger";

type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: EndpointMessage[] };
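
/**
 * Streams a conversation turn as MessageUpdate events: reasoning status and
 * stream updates while the model "thinks", token stream updates for the
 * visible answer, energy/duration metadata, and a FinalAnswer once the
 * endpoint reports the completed text.
 */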
export async function* generate(
	{ model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
	toolResults: ToolResult[],
	preprompt?: string,
	tools?: Tool[]
): AsyncIterable<MessageUpdate> {
	// reasoning mode is false by default
	let reasoning = false;
	let reasoningBuffer = "";
	let lastReasoningUpdate = new Date();
	let status = "";
	const startTime = new Date();
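	// reasoningBuffer accumulates the model's reasoning tokens; status holds the
	// latest background summary of that buffer and is yielded with the next token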
	if (
		model.reasoning &&
		// if the beginToken is an empty string, the model starts in reasoning mode
		(model.reasoning.type === "regex" ||
			model.reasoning.type === "summarize" ||
			(model.reasoning.type === "tokens" && model.reasoning.beginToken === ""))
	) {
		// if the model has reasoning in regex or summarize mode, it starts in reasoning mode
		// and we extract the answer from the reasoning
		reasoning = true;
		yield {
			type: MessageUpdateType.Reasoning,
			subtype: MessageReasoningUpdateType.Status,
			status: "Started reasoning...",
		};
	}
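
	// stream outputs from the endpoint; each `output` carries an incremental
	// token and, on the final event, the completed `generated_text`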
	for await (const output of await endpoint({
		messages,
		preprompt,
		continueMessage: isContinue,
		generateSettings: assistant?.generateSettings,
		tools,
		toolResults,
		isMultimodal: model.multimodal,
		conversationId: conv._id,
	})) {
		// text generation completed
		if (output.generated_text) {
			let interrupted =
				!output.token.special && !model.parameters.stop?.includes(output.token.text);
			let text = output.generated_text.trimEnd();
			// strip a trailing stop token; ending on one means generation
			// finished normally rather than being interrupted
			for (const stopToken of model.parameters.stop ?? []) {
				if (!text.endsWith(stopToken)) continue;
				interrupted = false;
				text = text.slice(0, text.length - stopToken.length);
			}

			let finalAnswer = text;
			if (model.reasoning && model.reasoning.type === "regex") {
				// the regex's first capture group is expected to hold the final
				// answer; fall back to the full text if it doesn't match
				const regex = new RegExp(model.reasoning.regex);
				finalAnswer = regex.exec(reasoningBuffer)?.[1] ?? text;
			} else if (model.reasoning && model.reasoning.type === "summarize") {
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status: "Summarizing reasoning...",
				};
				try {
					const summary = yield* generateFromDefaultEndpoint({
						messages: [
							{
								from: "user",
								content: `Question: ${
									messages[messages.length - 1].content
								}\n\nReasoning: ${reasoningBuffer}`,
							},
						],
						preprompt: `Your task is to summarize concisely all your reasoning steps and then give the final answer. Keep it short, one short paragraph at most. If the reasoning steps explicitly include a code solution, make sure to include it in your answer.

If the user is just having a casual conversation that doesn't require explanations, answer directly without explaining your steps, otherwise make sure to summarize step by step, skipping dead-ends in your reasoning and removing excess detail.

Do not use prefixes such as Response: or Answer: when answering the user.`,
						generateSettings: {
							max_new_tokens: 1024,
						},
					});
					finalAnswer = summary;
					yield {
						type: MessageUpdateType.Reasoning,
						subtype: MessageReasoningUpdateType.Status,
						status: `Done in ${Math.round((new Date().getTime() - startTime.getTime()) / 1000)}s.`,
					};
				} catch (e) {
					// if summarization fails, fall back to the raw generated text
					finalAnswer = text;
					logger.error(e);
				}
			} else if (model.reasoning && model.reasoning.type === "tokens") {
				// make sure to remove the content of the reasoning buffer from
				// the final answer to avoid duplication
				// if the beginToken is an empty string, we don't need to remove anything
				const beginIndex = model.reasoning.beginToken
					? reasoningBuffer.indexOf(model.reasoning.beginToken)
					: 0;
				const endIndex = reasoningBuffer.lastIndexOf(model.reasoning.endToken);

				if (beginIndex !== -1 && endIndex !== -1) {
					// Remove the reasoning section (including tokens) from the final answer
					finalAnswer =
						text.slice(0, beginIndex) + text.slice(endIndex + model.reasoning.endToken.length);
				}
			}
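
			// emit the final answer; `interrupted` signals that generation stopped
			// without reaching a stop token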
			yield {
				type: MessageUpdateType.FinalAnswer,
				text: finalAnswer,
				interrupted,
				webSources: output.webSources,
			};

			if (output.energy_consumption !== undefined) {
				// convert from mJ to Wh: mJ -> J (/1000), then J -> Wh (/3600)
				const energyUsedWh = output.energy_consumption / 1000 / 3600;
				logger.debug(`energyUsedWh: ${energyUsedWh}`);
				yield {
					type: MessageUpdateType.Metadata,
					key: "energy_wh",
					value: energyUsedWh,
				};
			}
			continue;
		}
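
		// token-level reasoning markers: beginToken switches reasoning mode on,
		// endToken switches it off; both are kept in the reasoning buffer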
		if (model.reasoning && model.reasoning.type === "tokens") {
			if (output.token.text === model.reasoning.beginToken) {
				reasoning = true;
				reasoningBuffer += output.token.text;
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status: "Started thinking...",
				};
				continue;
			} else if (output.token.text === model.reasoning.endToken) {
				reasoning = false;
				reasoningBuffer += output.token.text;
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status: `Done in ${Math.round((new Date().getTime() - startTime.getTime()) / 1000)}s.`,
				};
				continue;
			}
		}
		// ignore special tokens
		if (output.token.special) continue;

		// pass down normal token
		if (reasoning) {
			reasoningBuffer += output.token.text;

			// yield the latest background status update, if one has arrived
			if (status !== "") {
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status,
				};
				status = "";
			}
			// kick off a fresh reasoning summary roughly every 4 seconds
			if (new Date().getTime() - lastReasoningUpdate.getTime() > 4000) {
				lastReasoningUpdate = new Date();
				// the summary runs in the background; a synchronous try/catch cannot
				// catch a rejected promise, so errors are handled with .catch()
				generateSummaryOfReasoning(reasoningBuffer)
					.then((summary) => {
						status = summary;
					})
					.catch((e) => {
						logger.error(e);
					});
			}
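
			// stream the raw reasoning token so the client can render the chain
			// of thought as it arrives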
			yield {
				type: MessageUpdateType.Reasoning,
				subtype: MessageReasoningUpdateType.Stream,
				token: output.token.text,
			};
		} else {
			yield { type: MessageUpdateType.Stream, token: output.token.text };
		}
		if (!output.token.special) {
			// simulated metadata: estimate energy from elapsed time, assuming a
			// constant draw of P = 55 W (an H100 can draw up to 700 W)
			const durationInSeconds = (new Date().getTime() - startTime.getTime()) / 1000;
			const energyUsedWhSim = 55 * (durationInSeconds / 3600);
			logger.debug(`energyUsedWhSim: ${energyUsedWhSim} (model: ${model.name})`);
			yield {
				type: MessageUpdateType.Metadata,
				key: "energy_wh_sim",
				value: energyUsedWhSim,
			};
			yield {
				type: MessageUpdateType.Metadata,
				key: "duration_seconds",
				value: durationInSeconds,
			};
		}
		// abort check
		const date = AbortedGenerations.getInstance().getList().get(conv._id.toString());
		if (date && date > promptedAt) break;

		// no output check
		if (!output) break;
	}
}
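
// A minimal consumption sketch (hypothetical: `ctx` stands in for a
// GenerateContext assembled by the caller, e.g. in the textGeneration flow):
//
//   for await (const update of generate(ctx, [], preprompt, tools)) {
//     if (update.type === MessageUpdateType.FinalAnswer) {
//       console.log(update.text, update.interrupted);
//     }
//   }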