import os
import re
import tempfile
import shutil
import torch
import gradio as gr
from pathlib import Path
from typing import Optional, List, Union
import gc
import time
# Docling imports
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline
# LangChain imports
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.schema import Document
# Transformers imports for IBM Granite model
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Initialize IBM Granite model and tokenizer
print("Loading Granite model and tokenizer...")
model_name = "ibm-granite/granite-3.3-8b-instruct"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Create quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # 4-bit quantization for better memory efficiency
    bnb_4bit_compute_dtype=torch.bfloat16  # bfloat16 compute alongside the 4-bit weights
)
# Load model with optimization for GPU
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config
)
print("Model loaded successfully!")
# Helper function to detect document format
def get_document_format(file_path) -> Optional[InputFormat]:
    """Determine the document format based on file extension"""
    try:
        file_path = str(file_path)
        extension = os.path.splitext(file_path)[1].lower()
        format_map = {
            '.pdf': InputFormat.PDF,
            '.docx': InputFormat.DOCX,
            '.doc': InputFormat.DOCX,
            '.pptx': InputFormat.PPTX,
            '.html': InputFormat.HTML,
            '.htm': InputFormat.HTML
        }
        return format_map.get(extension)
    except Exception as e:
        print(f"Error in get_document_format: {str(e)}")
        return None
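# Illustrative lookups (the check is extension-based; file content is never inspected):
#   get_document_format("report.pdf")  -> InputFormat.PDF
#   get_document_format("deck.PPTX")   -> InputFormat.PPTX (extension is lowercased first)
#   get_document_format("notes.txt")   -> None, i.e. unsupported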
# Function to convert documents to markdown
def convert_document_to_markdown(doc_path) -> str:
    """Convert document to markdown using simplified pipeline"""
    try:
        # Convert to absolute path string
        input_path = os.path.abspath(str(doc_path))
        print(f"Converting document: {doc_path}")
        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy input file to temp directory
            temp_input = os.path.join(temp_dir, os.path.basename(input_path))
            shutil.copy2(input_path, temp_input)
            # Configure pipeline options
            pipeline_options = PdfPipelineOptions()
            pipeline_options.do_ocr = False  # Disable OCR for performance
            pipeline_options.do_table_structure = True
            # Create converter with optimized options
            converter = DocumentConverter(
                allowed_formats=[
                    InputFormat.PDF,
                    InputFormat.DOCX,
                    InputFormat.HTML,
                    InputFormat.PPTX,
                ],
                format_options={
                    InputFormat.PDF: PdfFormatOption(
                        pipeline_options=pipeline_options,
                    ),
                    InputFormat.DOCX: WordFormatOption(
                        pipeline_cls=SimplePipeline
                    )
                }
            )
            # Convert document
            print("Starting conversion...")
            conv_result = converter.convert(temp_input)
            if not conv_result or not conv_result.document:
                raise ValueError(f"Failed to convert document: {doc_path}")
            # Export to markdown
            print("Exporting to markdown...")
            md = conv_result.document.export_to_markdown()
            # Write the markdown next to the original input file
            output_dir = os.path.dirname(input_path)
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            md_path = os.path.join(output_dir, f"{base_name}_converted.md")
            with open(md_path, "w", encoding="utf-8") as fp:
                fp.write(md)
            return md_path
    except Exception as e:
        return f"Error converting document: {str(e)}"
# Improved text processing function
def clean_and_prepare_text(markdown_path):
    """Load, clean, and prepare document text for better processing"""
    try:
        # Load the document
        loader = UnstructuredMarkdownLoader(str(markdown_path))
        documents = loader.load()
        if not documents:
            return None, "No content could be extracted from the document."
        # Combine all document content, preserving paragraph breaks
        raw_text = "\n\n".join(doc.page_content for doc in documents)
        # Split into paragraphs BEFORE normalizing whitespace; normalizing first
        # would destroy the "\n\n" boundaries and collapse everything into one paragraph
        paragraphs = [p.strip() for p in raw_text.split("\n\n") if p.strip()]
        # Create structured documents for better processing
        processed_docs = []
        for i, para in enumerate(paragraphs):
            # 1. Normalize whitespace within the paragraph
            para = " ".join(para.split())
            # 2. Fix common OCR and conversion artifacts (stray space before punctuation)
            para = para.replace(" .", ".").replace(" ,", ",")
            # 3. Ensure a space after sentence-ending punctuation, but only when a
            #    capital letter follows, so decimals and abbreviations stay intact
            para = re.sub(r'([.!?])([A-Z])', r'\1 \2', para)
            if len(para) > 10:  # Skip very short paragraphs
                processed_docs.append(Document(
                    page_content=para,
                    metadata={"source": str(markdown_path), "paragraph": i}
                ))
        return processed_docs, None
    except Exception as e:
        return None, f"Error processing document text: {str(e)}"
# Improved text splitting configuration
def create_optimized_text_splitter():
    """Create an optimized text splitter for document processing"""
    return RecursiveCharacterTextSplitter(
        chunk_size=800,     # Slightly smaller for more focused chunks
        chunk_overlap=150,  # Increased overlap to maintain context
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ":", " ", ""]  # More comprehensive separators
    )
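# What the splitter produces (illustrative): chunks of at most ~800 characters,
# split preferentially at paragraph, then line, then sentence boundaries, with up
# to 150 characters repeated between consecutive chunks so context survives the cut:
#   splitter = create_optimized_text_splitter()
#   chunks = splitter.split_documents(processed_docs)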
# Function to generate a summary using the IBM Granite model
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
    """Generate a summary from document chunks using the IBM Granite model

    Args:
        chunks: List of document chunks to summarize
        length_type: Either "sentences" or "paragraphs"
        length_count: Number of sentences (1-10) or paragraphs (1-3)
    """
    # Print debug information
    print(f"Generating summary with length_type={length_type}, length_count={length_count}")
    # Ensure length_count is an integer
    try:
        length_count = int(length_count)
    except (ValueError, TypeError):
        print(f"Failed to convert length_count to int: {length_count}, using default 3")
        length_count = 3
    # Apply limits based on type
    if length_type == "sentences":
        length_count = max(1, min(10, length_count))  # Limit to 1-10 sentences
    else:  # paragraphs
        length_count = max(1, min(3, length_count))  # Limit to 1-3 paragraphs
    # Clean and concatenate the text from chunks:
    # remove excessive newlines and whitespace, then join
    cleaned_chunks = []
    for chunk in chunks:
        cleaned_chunks.append(' '.join(chunk.page_content.split()))
    combined_text = " ".join(cleaned_chunks)
    # Explicit, forceful prompt structure
    if length_type == "sentences":
        length_instruction = (
            f"Create a concise summary that is EXACTLY {length_count} complete sentences. "
            f"Not {length_count - 1} sentences. Not {length_count + 1} sentences. "
            f"EXACTLY {length_count} sentences."
        )
    else:  # paragraphs
        length_instruction = (
            f"Create a concise summary that is EXACTLY {length_count} paragraphs. "
            f"Each paragraph should be 2-4 sentences long. "
            f"Not {length_count - 1} paragraphs. Not {length_count + 1} paragraphs. "
            f"EXACTLY {length_count} paragraphs."
        )
    # Detailed prompt that spells out what the summary must (and must not) contain
    prompt = f"""<instruction>
You are an expert document summarizer. Your task is to create a high-quality summary of the following text.
{length_instruction}
Remember:
- Your summary must capture the main points of the document
- Your summary must be in your own words (not copied text)
- Your summary must be clearly written and well-structured
- Do not include any explanations, headings, bullet points, or additional formatting
- Respond ONLY with the summary text itself
</instruction>
<text>
{combined_text}
</text>
"""
    # Calculate an appropriate max_new_tokens budget with strict limits
    if length_type == "sentences":
        # Roughly 40 tokens per sentence (includes a generous buffer)
        max_tokens = length_count * 40
    else:  # paragraphs
        # Roughly 150 tokens per paragraph (includes a generous buffer)
        max_tokens = length_count * 150
    # Clamp to a sane range
    max_tokens = max(100, min(1500, max_tokens))
    print(f"Using max_new_tokens={max_tokens}")
    # Generate with low temperature for more consistent results
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,        # Lower temperature for more deterministic output
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2  # Discourage repetition
        )
    # Decode only the newly generated tokens (everything after the prompt)
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    summary = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # Post-process the summary to enforce the sentence-count constraint
    if length_type == "sentences":
        # Simple sentence counting based on periods
        sentences = [s.strip() for s in summary.split('.') if s.strip()]
        if len(sentences) > length_count:
            # Take only the requested number of sentences
            summary = '. '.join(sentences[:length_count]) + '.'
        elif len(sentences) < length_count:
            # If we have too few sentences, log this issue
            print(f"Warning: generated only {len(sentences)} sentences instead of {length_count}")
    return summary.strip()
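# Illustrative call, given chunks produced by the pipeline below:
#   summary = generate_summary(chunks, length_type="sentences", length_count=3)
# The "EXACTLY N" prompt wording is only a soft constraint on the model; the
# sentence post-processing above is what actually enforces the upper bound.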
# Function to build the vector store from document chunks
def process_document_chunks(texts):
    """Embed document chunks and index them in a FAISS vector store"""
    try:
        # Create embeddings with optimized settings
        embeddings = HuggingFaceEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={'trust_remote_code': True}
        )
        # Create vector store with cosine distance for better retrieval
        vectorstore = FAISS.from_documents(
            texts,
            embeddings,
            distance_strategy=DistanceStrategy.COSINE
        )
        return vectorstore
    except Exception as e:
        print(f"Error in document processing: {str(e)}")
        # Fall back to FAISS defaults if the optimized settings fail
        embeddings = HuggingFaceEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={'trust_remote_code': True}
        )
        return FAISS.from_documents(texts, embeddings)
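# Illustrative retrieval check on the resulting store (hypothetical query string):
#   vectorstore = process_document_chunks(texts)
#   for doc in vectorstore.similarity_search("main findings", k=4):
#       print(doc.metadata.get("paragraph"), doc.page_content[:80])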
# Main function to process document and generate summary
def process_document(
    file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
    length_type: str = "sentences",
    length_count: int = 3,
    progress=gr.Progress()
):
    """Process a document file and generate a summary"""
    try:
        # Process input file
        if not file_obj:
            return "Please provide a file to summarize."
        document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
        # Validate document format
        format_type = get_document_format(document_path)
        if not format_type:
            return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file."
        # Convert document to markdown
        progress(0.3, "Converting document to markdown...")
        markdown_path = convert_document_to_markdown(document_path)
        if markdown_path.startswith("Error"):
            return markdown_path
        # Clean and prepare the text
        progress(0.4, "Processing document text...")
        processed_docs, error = clean_and_prepare_text(markdown_path)
        if error:
            return error
        # Split the documents with the optimized splitter
        text_splitter = create_optimized_text_splitter()
        texts = text_splitter.split_documents(processed_docs)
        if not texts:
            return "No text could be extracted from the document."
        # Create vector store with efficient processing
        progress(0.6, "Processing document content...")
        vectorstore = process_document_chunks(texts)
        # Collect all chunks from the store in small batches for memory efficiency
        progress(0.8, "Generating summary...")
        all_chunks = []
        batch_size = 4  # Small batch size for memory efficiency
        doc_ids = list(vectorstore.index_to_docstore_id.values())
        for i in range(0, len(doc_ids), batch_size):
            batch_ids = doc_ids[i:i + batch_size]
            batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
            all_chunks.extend(batch_chunks)
            # Force garbage collection to free memory
            gc.collect()
            # Sleep briefly to allow memory cleanup
            time.sleep(0.1)
        # Case 1: Very small documents - use all chunks directly
        if len(all_chunks) <= 8:
            return generate_summary(
                all_chunks,
                length_type=length_type.lower(),
                length_count=length_count
            )
        # Case 2: Medium-sized documents - summarize a single batch
        elif len(all_chunks) <= 16:
            return generate_summary(
                all_chunks[:8],  # The first 8 chunks usually contain the key information
                length_type=length_type.lower(),
                length_count=length_count
            )
        # Case 3: Large documents - summarize in multiple batches, then reduce
        else:
            # First pass: generate an intermediate summary for each batch
            summaries = []
            for i in range(0, len(all_chunks), batch_size):
                batch = all_chunks[i:i + batch_size]
                summary = generate_summary(
                    batch,
                    length_type="paragraphs",  # Use paragraphs for intermediate summaries
                    length_count=1             # One paragraph per batch
                )
                summaries.append(summary)
                # Force garbage collection
                gc.collect()
            # Second pass: generate the final summary from the batch summaries
            final_summary = generate_summary(
                [Document(page_content=s) for s in summaries],
                length_type=length_type.lower(),
                length_count=length_count
            )
            return final_summary
    except Exception as e:
        return f"Error processing document: {str(e)}"
# Create Gradio interface
def create_gradio_interface():
    """Create and launch the Gradio interface"""
    with gr.Blocks(title="Granite Document Summarization") as app:
        gr.Markdown("# Granite Document Summarization")
        gr.Markdown("Upload a document to generate a summary.")
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document (PDF, DOCX, PPTX, HTML)",
                    file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"]
                )
                with gr.Row():
                    length_type = gr.Radio(
                        choices=["Sentences", "Paragraphs"],
                        value="Sentences",
                        label="Summary Length Type"
                    )
                with gr.Row():
                    # Slider for sentence count (1-10)
                    sentence_count = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Sentences",
                        visible=True
                    )
                    # Radio for paragraph count (1-3)
                    paragraph_count = gr.Radio(
                        choices=["1", "2", "3"],
                        value="1",
                        label="Number of Paragraphs",
                        visible=False
                    )
                submit_btn = gr.Button("Summarize", variant="primary")
            with gr.Column(scale=2):
                output = gr.TextArea(
                    label="Summary",
                    lines=15,
                    max_lines=30
                )
        # Show/hide the appropriate count selector
        def update_count_visibility(length_type):
            is_sentences = length_type == "Sentences"
            return [
                gr.update(visible=is_sentences),     # sentence_count
                gr.update(visible=not is_sentences)  # paragraph_count
            ]
        length_type.change(
            fn=update_count_visibility,
            inputs=[length_type],
            outputs=[sentence_count, paragraph_count]
        )
        # Function to handle form submission properly
        def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
            # Convert capitalized length_type to lowercase for processing
            length_type_lower = length_type.lower()
            print(f"Processing with length_type={length_type}, sentence_count={sentence_count}, paragraph_count={paragraph_count}")
            # Determine count based on the selected length type
            if length_type_lower == "sentences":
                # For sentences, use the slider value directly
                try:
                    count = int(sentence_count)
                    count = max(1, min(10, count))  # Ensure within range 1-10
                    print(f"Using sentence count: {count}")
                except (ValueError, TypeError):
                    print(f"Invalid sentence count: {sentence_count}, using default 3")
                    count = 3
            else:
                # For paragraphs, convert to int if needed
                try:
                    if isinstance(paragraph_count, bool):
                        count = 1  # Default if a boolean sneaks in from the visibility toggle
                    else:
                        count = int(paragraph_count)  # Handles both str (radio value) and int
                    count = max(1, min(3, count))  # Ensure within range 1-3
                    print(f"Using paragraph count: {count}")
                except (ValueError, TypeError):
                    print(f"Invalid paragraph count: {paragraph_count}, using default 1")
                    count = 1
            return process_document(file, length_type_lower, count)
        submit_btn.click(
            fn=process_document_wrapper,
            inputs=[file_input, length_type, sentence_count, paragraph_count],
            outputs=output
        )
        gr.Markdown("""
        ## How to use:
        1. Upload a document (PDF, DOCX, PPTX, HTML)
        2. Choose your summary length preference:
           - Number of Sentences (1-10)
           - Number of Paragraphs (1-3)
        3. Click "Summarize" to process the document

        *This application uses the IBM Granite 3.3-8b model to generate summaries.*
        """)
    return app
# Launch the application
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch()
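# On Spaces the default launch() is all that's needed; running locally you could
# pass e.g. app.launch(share=True) for a temporary public link via Gradio's tunnel.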