# NOTE: removed web-scrape artifacts that preceded this file
# (page status text, file size, commit hash, and a gutter of line numbers).
from smolagents import Tool
from typing import Dict, Any, Optional
import warnings
# Suppress unnecessary warnings
# NOTE(review): this silences ALL warning categories process-wide as an import-time
# side effect, which can hide real deprecation/runtime warnings from transformers
# and torch — presumably intentional for clean agent output; confirm, or scope to
# specific categories with warnings.filterwarnings("ignore", category=...).
warnings.filterwarnings("ignore")
class TextSummarizerTool(Tool):
    """smolagents tool that summarizes text with Hugging Face summarization pipelines.

    The heavy dependencies (transformers, torch) are imported lazily on first use,
    and the pipeline is cached between calls. Errors are reported as returned
    strings rather than raised, matching the tool-call convention of the agent.
    """

    name = "text_summarizer"
    description = """
    Summarizes text using various summarization methods and models.
    This tool can generate concise summaries of longer texts while preserving key information.
    It supports different summarization models and customizable parameters.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to be summarized",
        },
        "model": {
            "type": "string",
            "description": "Summarization model to use (default: 'facebook/bart-large-cnn')",
            "nullable": True
        },
        "max_length": {
            "type": "integer",
            "description": "Maximum length of the summary in tokens (default: 130)",
            "nullable": True
        },
        "min_length": {
            "type": "integer",
            "description": "Minimum length of the summary in tokens (default: 30)",
            "nullable": True
        },
        "style": {
            "type": "string",
            "description": "Style of summary: 'concise', 'detailed', or 'bullet_points' (default: 'concise')",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Initialize the Text Summarizer Tool with default settings."""
        super().__init__()
        self.default_model = "facebook/bart-large-cnn"
        # Whitelist of accepted checkpoints mapped to human-readable descriptions.
        self.available_models = {
            "facebook/bart-large-cnn": "BART CNN (good for news)",
            "sshleifer/distilbart-cnn-12-6": "DistilBART (faster, smaller)",
            "google/pegasus-xsum": "Pegasus (extreme summarization)",
            "facebook/bart-large-xsum": "BART XSum (very concise)",
            "philschmid/bart-large-cnn-samsum": "BART SamSum (good for conversations)"
        }
        # Pipeline will be lazily loaded on the first forward() call.
        # _loaded_model tracks which checkpoint is actually in memory — it may be
        # the fallback default rather than the one the caller requested.
        self._pipeline = None
        self._loaded_model = None

    def _create_pipeline(self, model_name: str):
        """Build a transformers summarization pipeline, on GPU when one is available."""
        # Imported here (not at module level) so importing this module stays cheap
        # and does not require transformers/torch unless the tool is actually used.
        from transformers import pipeline
        import torch

        # device=0 selects the first CUDA device; -1 means CPU.
        device = 0 if torch.cuda.is_available() else -1
        return pipeline("summarization", model=model_name, device=device)

    def _load_pipeline(self, model_name: str) -> bool:
        """Load the summarization pipeline for ``model_name``.

        On failure, fall back to ``self.default_model``. Returns True when some
        pipeline was loaded successfully, False when both attempts failed.
        """
        try:
            self._pipeline = self._create_pipeline(model_name)
            self._loaded_model = model_name
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            try:
                # Fall back to default model
                self._pipeline = self._create_pipeline(self.default_model)
                self._loaded_model = self.default_model
                return True
            except Exception as fallback_error:
                print(f"Error loading fallback model: {str(fallback_error)}")
                # Leave the cache in a consistent "nothing loaded" state.
                self._pipeline = None
                self._loaded_model = None
                return False

    def _format_as_bullets(self, summary: str) -> str:
        """Format a summary as bullet points, one sentence per bullet."""
        import re

        # Split on whitespace that follows sentence-ending punctuation.
        sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', summary) if s.strip()]
        # Drop very short fragments that are likely splitting artifacts.
        return "\n".join(f"• {s}" for s in sentences if len(s) >= 15)

    def forward(self, text: str, model: Optional[str] = None,
                max_length: Optional[int] = None, min_length: Optional[int] = None,
                style: Optional[str] = None) -> str:
        """
        Summarize the input text.
        Args:
            text: The text to summarize
            model: Summarization model to use
            max_length: Maximum summary length in tokens
            min_length: Minimum summary length in tokens
            style: Style of summary ('concise', 'detailed', or 'bullet_points')
        Returns:
            Summarized text, or a human-readable error/validation message
        """
        # Apply defaults for omitted parameters.
        if model is None:
            model = self.default_model
        if max_length is None:
            max_length = 130
        if min_length is None:
            min_length = 30
        if style is None:
            style = "concise"

        # Validate model choice before doing any expensive work.
        if model not in self.available_models:
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Reject inputs too short to summarize before loading a heavy model.
        if len(text.split()) < 20:
            return "The input text is too short to summarize effectively."

        # (Re)load the pipeline when nothing is loaded yet or when a different
        # checkpoint than the one in memory was requested.
        if self._pipeline is None or self._loaded_model != model:
            if not self._load_pipeline(model):
                return "Failed to load summarization model. Please try a different model."

        # Adjust length bounds based on the requested style.
        if style == "concise":
            max_length = min(100, max_length)
            min_length = min(30, min_length)
        elif style == "detailed":
            max_length = max(150, max_length)
            min_length = max(50, min_length)
        # Keep the bounds consistent: min_length must never exceed max_length.
        min_length = min(min_length, max_length)

        try:
            # Truncate very long inputs (word count is only a rough proxy for the
            # ~1024-2048-token context limits of most summarization models).
            max_input_length = 1024
            words = text.split()
            if len(words) > max_input_length:
                text = " ".join(words[:max_input_length])
                note = "\n\nNote: The input was truncated due to length limits."
            else:
                note = ""

            summary = self._pipeline(
                text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False
            )
            result = summary[0]['summary_text']

            # Format the result based on style.
            if style == "bullet_points":
                result = self._format_as_bullets(result)

            # Append provenance metadata so the agent can report which model ran.
            metadata = f"\n\nSummarized using: {self.available_models.get(model, model)}"
            return result + metadata + note
        except Exception as e:
            return f"Error summarizing text: {str(e)}"

    def get_available_models(self) -> Dict[str, str]:
        """Return the available models with descriptions (name -> description)."""
        # Return a copy so callers cannot mutate the tool's internal whitelist.
        return dict(self.available_models)
# Example usage:
# summarizer = TextSummarizerTool()
# result = summarizer("Long text goes here...", model="facebook/bart-large-cnn", style="bullet_points")
# print(result)