// Last change: "add streaming of energy, J conversion" (commit 145a107) by jdelavande (HF Staff)
import type { ToolResult, Tool } from "$lib/types/Tool";
import {
MessageReasoningUpdateType,
MessageUpdateType,
type MessageUpdate,
} from "$lib/types/MessageUpdate";
import { AbortedGenerations } from "../abortedGenerations";
import type { TextGenerationContext } from "./types";
import type { EndpointMessage } from "../endpoints/endpoints";
import { generateFromDefaultEndpoint } from "../generateFromDefaultEndpoint";
import { generateSummaryOfReasoning } from "./reasoning";
import { logger } from "../logger";
/**
 * Context accepted by `generate`: a `TextGenerationContext` whose `messages`
 * field is replaced by an array of `EndpointMessage`.
 */
type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: EndpointMessage[] };
/**
 * Streams the output of `endpoint` as a sequence of `MessageUpdate`s.
 *
 * For every non-special token it yields either a plain stream update or, while
 * in reasoning mode, a reasoning stream update, followed by simulated energy and
 * duration metadata. When the endpoint reports `generated_text`, it yields the
 * final answer (with stop tokens and, depending on `model.reasoning.type`, the
 * reasoning section stripped or summarized) and, if the endpoint provided
 * `energy_consumption`, an `energy_wh` metadata update.
 *
 * The loop stops early when `AbortedGenerations` holds an abort date for this
 * conversation that is later than `promptedAt`.
 *
 * @param ctx - Generation context (model, endpoint, conversation, messages, ...).
 * @param toolResults - Tool results forwarded to the endpoint.
 * @param preprompt - Optional preprompt forwarded to the endpoint.
 * @param tools - Optional tools forwarded to the endpoint.
 * @returns An async iterable of message updates.
 */
export async function* generate(
	{ model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
	toolResults: ToolResult[],
	preprompt?: string,
	tools?: Tool[]
): AsyncIterable<MessageUpdate> {
	// Minimum delay between two background requests for a reasoning status summary.
	const reasoningStatusIntervalMs = 4000;
	// Constant power draw, in watts, assumed by the simulated energy estimate.
	const simulatedPowerWatts = 55;

	// reasoning mode is false by default
	let reasoning = false;
	let reasoningBuffer = "";
	let lastReasoningUpdate = Date.now();
	// Filled asynchronously by the summary request; flushed on the next reasoning token.
	let status = "";
	const startTime = Date.now();

	/** Seconds elapsed since generation started. */
	const elapsedSeconds = (): number => (Date.now() - startTime) / 1000;

	if (
		model.reasoning &&
		// if the beginToken is an empty string, the model starts in reasoning mode
		(model.reasoning.type === "regex" ||
			model.reasoning.type === "summarize" ||
			(model.reasoning.type === "tokens" && model.reasoning.beginToken === ""))
	) {
		// if the model has reasoning in regex or summarize mode, it starts in reasoning mode
		// and we extract the answer from the reasoning
		reasoning = true;
		yield {
			type: MessageUpdateType.Reasoning,
			subtype: MessageReasoningUpdateType.Status,
			status: "Started reasoning...",
		};
	}

	for await (const output of await endpoint({
		messages,
		preprompt,
		continueMessage: isContinue,
		generateSettings: assistant?.generateSettings,
		tools,
		toolResults,
		isMultimodal: model.multimodal,
		conversationId: conv._id,
	})) {
		// text generation completed
		if (output.generated_text) {
			// Considered interrupted unless the last token is special or a stop token,
			// or the text ends with a stop token (checked below).
			let interrupted =
				!output.token.special && !model.parameters.stop?.includes(output.token.text);

			let text = output.generated_text.trimEnd();
			for (const stopToken of model.parameters.stop ?? []) {
				if (!text.endsWith(stopToken)) continue;

				interrupted = false;
				text = text.slice(0, text.length - stopToken.length);
			}

			let finalAnswer = text;
			if (model.reasoning && model.reasoning.type === "regex") {
				// The answer is the first capture group matched against the reasoning buffer.
				const regex = new RegExp(model.reasoning.regex);
				finalAnswer = regex.exec(reasoningBuffer)?.[1] ?? text;
			} else if (model.reasoning && model.reasoning.type === "summarize") {
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status: "Summarizing reasoning...",
				};

				try {
					const summary = yield* generateFromDefaultEndpoint({
						messages: [
							{
								from: "user",
								content: `Question: ${messages[messages.length - 1].content}\n\nReasoning: ${reasoningBuffer}`,
							},
						],
						preprompt: `Your task is to summarize concisely all your reasoning steps and then give the final answer. Keep it short, one short paragraph at most. If the reasoning steps explicitly include a code solution, make sure to include it in your answer.
If the user is just having a casual conversation that doesn't require explanations, answer directly without explaining your steps, otherwise make sure to summarize step by step, make sure to skip dead-ends in your reasoning and removing excess detail.
Do not use prefixes such as Response: or Answer: when answering to the user.`,
						generateSettings: {
							max_new_tokens: 1024,
						},
					});
					finalAnswer = summary;
					yield {
						type: MessageUpdateType.Reasoning,
						subtype: MessageReasoningUpdateType.Status,
						status: `Done in ${Math.round(elapsedSeconds())}s.`,
					};
				} catch (e) {
					// Summarization is best-effort: fall back to the raw text.
					finalAnswer = text;
					logger.error(e);
				}
			} else if (model.reasoning && model.reasoning.type === "tokens") {
				// make sure to remove the content of the reasoning buffer from
				// the final answer to avoid duplication

				// if the beginToken is an empty string, we don't need to remove anything
				const beginIndex = model.reasoning.beginToken
					? reasoningBuffer.indexOf(model.reasoning.beginToken)
					: 0;
				const endIndex = reasoningBuffer.lastIndexOf(model.reasoning.endToken);

				if (beginIndex !== -1 && endIndex !== -1) {
					// Remove the reasoning section (including tokens) from final answer
					// NOTE(review): the offsets are computed on `reasoningBuffer` but applied
					// to `text`; this presumably assumes the reasoning section sits at the
					// same offsets in both strings - TODO confirm.
					finalAnswer =
						text.slice(0, beginIndex) + text.slice(endIndex + model.reasoning.endToken.length);
				}
			}

			yield {
				type: MessageUpdateType.FinalAnswer,
				text: finalAnswer,
				interrupted,
				webSources: output.webSources,
			};

			if (output.energy_consumption !== undefined) {
				// The reported value is treated as millijoules: mJ -> J (/1000) -> Wh (/3600).
				const energyUsedWh = output.energy_consumption / 1000 / 3600;
				yield {
					type: MessageUpdateType.Metadata,
					key: "energy_wh",
					value: energyUsedWh,
				};
			}
			continue;
		}

		if (model.reasoning && model.reasoning.type === "tokens") {
			if (output.token.text === model.reasoning.beginToken) {
				reasoning = true;
				reasoningBuffer += output.token.text;
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status: "Started thinking...",
				};
				continue;
			} else if (output.token.text === model.reasoning.endToken) {
				reasoning = false;
				reasoningBuffer += output.token.text;
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status: `Done in ${Math.round(elapsedSeconds())}s.`,
				};
				continue;
			}
		}

		// ignore special tokens
		if (output.token.special) continue;

		// pass down normal token
		if (reasoning) {
			reasoningBuffer += output.token.text;

			// yield status update if it has changed
			if (status !== "") {
				yield {
					type: MessageUpdateType.Reasoning,
					subtype: MessageReasoningUpdateType.Status,
					status,
				};
				status = "";
			}

			// request a new status summary once the interval has elapsed
			if (Date.now() - lastReasoningUpdate > reasoningStatusIntervalMs) {
				lastReasoningUpdate = Date.now();
				try {
					// Fire-and-forget: the result is picked up through `status` on a later
					// token. The `.catch` handles an asynchronous rejection, which the
					// surrounding try/catch cannot see; the try/catch only covers a
					// synchronous throw.
					void generateSummaryOfReasoning(reasoningBuffer)
						.then((summary) => {
							status = summary;
						})
						.catch((e) => {
							logger.error(e);
						});
				} catch (e) {
					logger.error(e);
				}
			}
			yield {
				type: MessageUpdateType.Reasoning,
				subtype: MessageReasoningUpdateType.Stream,
				token: output.token.text,
			};
		} else {
			yield { type: MessageUpdateType.Stream, token: output.token.text };
		}

		// Simulated metadata, emitted after every non-special token (special tokens
		// were skipped above): energy = assumed constant power * elapsed time.
		const durationInSeconds = elapsedSeconds();
		const simulatedEnergyWh = simulatedPowerWatts * (durationInSeconds / 3600);
		yield {
			type: MessageUpdateType.Metadata,
			key: "energy_wh_sim",
			value: simulatedEnergyWh,
		};
		yield {
			type: MessageUpdateType.Metadata,
			key: "duration_seconds",
			value: durationInSeconds,
		};

		// abort check
		const date = AbortedGenerations.getInstance().getList().get(conv._id.toString());
		if (date && date > promptedAt) break;
	}
}