from smolagents import Tool
from typing import Dict, Any, Optional
import warnings

# Suppress unnecessary warnings (e.g. transformers/torch deprecation noise)
warnings.filterwarnings("ignore")


class TextSummarizerTool(Tool):
    """smolagents Tool that summarizes text with Hugging Face seq2seq pipelines.

    The transformers pipeline is loaded lazily on first use and cached; it is
    reloaded only when the caller requests a different model.
    """

    name = "text_summarizer"
    description = """
    Summarizes text using various summarization methods and models.
    This tool can generate concise summaries of longer texts while preserving key information.
    It supports different summarization models and customizable parameters.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to be summarized",
        },
        "model": {
            "type": "string",
            "description": "Summarization model to use (default: 'facebook/bart-large-cnn')",
            "nullable": True,
        },
        "max_length": {
            "type": "integer",
            "description": "Maximum length of the summary in tokens (default: 130)",
            "nullable": True,
        },
        "min_length": {
            "type": "integer",
            "description": "Minimum length of the summary in tokens (default: 30)",
            "nullable": True,
        },
        "style": {
            "type": "string",
            "description": "Style of summary: 'concise', 'detailed', or 'bullet_points' (default: 'concise')",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self):
        """Initialize the Text Summarizer Tool with default settings."""
        super().__init__()
        self.default_model = "facebook/bart-large-cnn"
        self.available_models = {
            "facebook/bart-large-cnn": "BART CNN (good for news)",
            "sshleifer/distilbart-cnn-12-6": "DistilBART (faster, smaller)",
            "google/pegasus-xsum": "Pegasus (extreme summarization)",
            "facebook/bart-large-xsum": "BART XSum (very concise)",
            "philschmid/bart-large-cnn-samsum": "BART SamSum (good for conversations)",
        }
        # Pipeline will be lazily loaded; _current_model tracks which model is
        # actually loaded (may differ from the requested one after a fallback).
        self._pipeline = None
        self._current_model: Optional[str] = None

    def _try_load(self, model_name: str) -> bool:
        """Load a summarization pipeline for *model_name* (raises on failure).

        Side effects: sets ``self._pipeline`` and ``self._current_model``.
        """
        # Local imports keep transformers/torch optional until first use.
        from transformers import pipeline
        import torch

        # Try to detect if GPU is available (device 0 = first CUDA GPU, -1 = CPU)
        device = 0 if torch.cuda.is_available() else -1
        self._pipeline = pipeline(
            "summarization",
            model=model_name,
            device=device,
        )
        self._current_model = model_name
        return True

    def _load_pipeline(self, model_name: str) -> bool:
        """Load the summarization pipeline, falling back to the default model.

        Returns True on success (of either the requested or default model),
        False if both attempts fail. Errors are reported via print, matching
        the tool's existing best-effort style.
        """
        try:
            return self._try_load(model_name)
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            try:
                # Fall back to default model
                return self._try_load(self.default_model)
            except Exception as fallback_error:
                print(f"Error loading fallback model: {str(fallback_error)}")
                return False

    def _format_as_bullets(self, summary: str) -> str:
        """Format a summary as bullet points, one sentence per bullet."""
        # Split the summary into sentences on terminal punctuation.
        import re

        sentences = re.split(r'(?<=[.!?])\s+', summary)
        sentences = [s.strip() for s in sentences if s.strip()]

        # Skip very short fragments (< 15 chars) that might be artifacts.
        return "\n".join(
            f"• {sentence}" for sentence in sentences if len(sentence) >= 15
        )

    def forward(
        self,
        text: str,
        model: Optional[str] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        style: Optional[str] = None,
    ) -> str:
        """
        Summarize the input text.

        Args:
            text: The text to summarize
            model: Summarization model to use
            max_length: Maximum summary length in tokens
            min_length: Minimum summary length in tokens
            style: Style of summary ('concise', 'detailed', or 'bullet_points')

        Returns:
            Summarized text (with a model-attribution footer), or a
            human-readable error message string on failure.
        """
        # Set default values if parameters are None
        if model is None:
            model = self.default_model
        if max_length is None:
            max_length = 130
        if min_length is None:
            min_length = 30
        if style is None:
            style = "concise"

        # Validate model choice
        if model not in self.available_models:
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Load the model if not already loaded, or if the requested model
        # differs from the one currently loaded. (Tracking the name ourselves
        # avoids fragile introspection of the pipeline object, which could
        # silently skip a reload when `.model` was absent.)
        if self._pipeline is None or self._current_model != model:
            if not self._load_pipeline(model):
                return "Failed to load summarization model. Please try a different model."

        # Adjust parameters based on style
        if style == "concise":
            max_length = min(100, max_length)
            min_length = min(30, min_length)
        elif style == "detailed":
            max_length = max(150, max_length)
            min_length = max(50, min_length)

        # Ensure text is not too short
        if len(text.split()) < 20:
            return "The input text is too short to summarize effectively."

        # Perform summarization
        try:
            # Truncate very long inputs if needed. NOTE(review): this counts
            # whitespace-split words as a proxy for tokens; most of these
            # models have limits around 1024-2048 actual tokens.
            max_input_length = 1024
            words = text.split()
            if len(words) > max_input_length:
                text = " ".join(words[:max_input_length])
                note = "\n\nNote: The input was truncated due to length limits."
            else:
                note = ""

            summary = self._pipeline(
                text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
            )
            result = summary[0]['summary_text']

            # Format the result based on style
            if style == "bullet_points":
                result = self._format_as_bullets(result)

            # Attribute the model that was actually used (the fallback model
            # may have been loaded instead of the requested one).
            used = self._current_model if self._current_model is not None else model
            metadata = f"\n\nSummarized using: {self.available_models.get(used, used)}"

            return result + metadata + note

        except Exception as e:
            return f"Error summarizing text: {str(e)}"

    def get_available_models(self) -> Dict[str, str]:
        """Return the dictionary of available models with descriptions."""
        return self.available_models


# Example usage:
# summarizer = TextSummarizerTool()
# result = summarizer("Long text goes here...", model="facebook/bart-large-cnn", style="bullet_points")
# print(result)