import { pipeline, env } from '@huggingface/transformers';
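// Cache downloaded model files in a local ./.cache directory.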
env.cacheDir = './.cache';
class TeapotAI {
/**
* Initializes the TeapotAI class.
* @param {object} options - Configuration options.
* @param {string} [options.modelId='tomasmcm/teapotai-teapotllm-onnx'] - The Hugging Face model ID.
* @param {boolean} [options.verbose=false] - Whether to print status messages.
* @param {object} [options.pipelineOptions={}] - Additional options passed to the transformers.js pipeline.
*/
constructor({
modelId = 'tomasmcm/teapotai-teapotllm-onnx',
verbose = false,
pipelineOptions = {}
} = {}) {
this.modelId = modelId;
this.verbose = verbose;
this.pipelineOptions = pipelineOptions;
this.generator = null;
this.isInitialized = false;
if (this.verbose) {
console.log(`TeapotAI instance created for model: ${this.modelId}`);
}
}
/**
* Asynchronously initializes the text generation pipeline.
* Must be called before using generate, query, or chat.
*/
async initialize() {
if (this.isInitialized) {
if (this.verbose) console.log("Pipeline already initialized.");
return;
}
try {
if (this.verbose) console.log(`Initializing generator pipeline for model: ${this.modelId}...`);
// pipeline(task, model, options): pass the model ID and any extra pipeline options straight through.
this.generator = await pipeline('text2text-generation', this.modelId, this.pipelineOptions);
this.isInitialized = true;
if (this.verbose) console.log("Pipeline initialized successfully.");
} catch (error) {
console.error("Failed to initialize pipeline:", error);
throw error;
}
}
/**
* Ensures the pipeline is initialized before proceeding.
* @private
*/
_ensureInitialized() {
if (!this.isInitialized || !this.generator) {
throw new Error("Pipeline not initialized. Call initialize() before using query(), generate(), or chat().");
}
}
/**
* Generates text based on the input string.
* (Internal method similar to the Python version's generate)
* @param {string} inputText - The text prompt to generate a response for.
* @returns {Promise<string>} The generated output from the model.
*/
async generate(inputText) {
this._ensureInitialized();
try {
if (this.verbose) console.log("Generating text...");
const output = await this.generator(inputText, {
max_new_tokens: 512, // cap the length of the generated response
});
const generatedText = output[0]?.generated_text?.trim() ?? "Error: Could not generate text.";
if (this.verbose) console.log("Text generation complete.");
return generatedText;
} catch (error) {
if (this.verbose) console.error("Error during text generation:", error);
return "Error: Generation failed.";
}
}
/**
* Handles a query and context to generate a response.
* (Focuses on the case where context is provided, skipping RAG)
* @param {string} query - The query string to be answered.
* @param {string} context - The context to guide the response.
* @returns {Promise<string>} The generated response based on the query and context.
*/
async query(query, context) {
this._ensureInitialized();
let inputText;
if (!context) {
if (this.verbose) console.warn("Context is empty. Proceeding without context enhancement.");
inputText = `Query: ${query}`;
} else {
inputText = `Context: ${context}\nQuery: ${query}`;
if (this.verbose) console.log("\nFormatted Input for Query:\n", inputText);
}
return this.generate(inputText);
}
/**
* Engages in a chat by taking a list of previous messages and generating a response.
* @param {Array<object>} conversationHistory - An array of message objects, each expected to have a 'content' property. E.g., [{ content: 'User: Hi' }, { content: 'Agent: Hello!' }]
* @returns {Promise<string>} The generated agent response based on the conversation history.
*/
async chat(conversationHistory) {
this._ensureInitialized();
if (!Array.isArray(conversationHistory)) {
throw new Error("conversationHistory must be an array of message objects.");
}
const chatHistoryString = conversationHistory
.map(message => message.content)
.join("\n");
const inputText = chatHistoryString + "\n" + "agent:";
if (this.verbose) console.log("\nFormatted Input for Chat:\n", inputText);
return this.generate(inputText);
}
}
export default TeapotAI;
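// Usage sketch (assumes an ESM environment with top-level await, e.g. recent Node.js;
// the query, context, and chat messages below are illustrative, not from the model card):
//
// import TeapotAI from './TeapotAI.js';
//
// const teapot = new TeapotAI({ verbose: true });
// await teapot.initialize(); // downloads and caches the model on first run
//
// const answer = await teapot.query(
//   'How tall is the Eiffel Tower?',
//   'The Eiffel Tower is 330 meters tall.'
// );
// console.log(answer);
//
// const reply = await teapot.chat([
//   { content: 'user: Hi, can you help me with a question?' },
//   { content: 'agent: Of course! What would you like to know?' },
//   { content: 'user: What is TeapotAI?' }
// ]);
// console.log(reply);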