import os
import tempfile
import shutil
import re
import torch
import gradio as gr
from typing import Optional, List, Union
import gc
import time
# Docling imports
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline
# LangChain imports
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.schema import Document
# Transformers imports for IBM Granite model
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Initialize IBM Granite model and tokenizer
print("Loading Granite model and tokenizer...")
model_name = "ibm-granite/granite-3.3-8b-instruct"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Create quantization config
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, # Use 4-bit quantization for better memory efficiency
bnb_4bit_compute_dtype=torch.bfloat16 # Use bfloat16 for computation with 4-bit quantization
)
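# Note: 4-bit quantization requires the bitsandbytes package and a CUDA GPU;
# on a CPU-only machine, drop quantization_config below and load the model in full precision.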
# Load model with optimization for GPU
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=quantization_config
)
print("Model loaded successfully!")
# Helper function to detect document format
def get_document_format(file_path) -> Optional[InputFormat]:
"""Determine the document format based on file extension"""
try:
file_path = str(file_path)
extension = os.path.splitext(file_path)[1].lower()
format_map = {
'.pdf': InputFormat.PDF,
'.docx': InputFormat.DOCX,
            '.doc': InputFormat.DOCX,  # note: legacy binary .doc files may not convert cleanly; Docling targets .docx
'.pptx': InputFormat.PPTX,
'.html': InputFormat.HTML,
'.htm': InputFormat.HTML
}
return format_map.get(extension)
except Exception as e:
print(f"Error in get_document_format: {str(e)}")
return None
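# Example: get_document_format("report.pdf") -> InputFormat.PDF;
# an unknown extension such as ".txt" returns None, which callers treat as unsupported.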
# Function to convert documents to markdown
def convert_document_to_markdown(doc_path) -> str:
"""Convert document to markdown using simplified pipeline"""
try:
# Convert to absolute path string
input_path = os.path.abspath(str(doc_path))
print(f"Converting document: {doc_path}")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Copy input file to temp directory
temp_input = os.path.join(temp_dir, os.path.basename(input_path))
shutil.copy2(input_path, temp_input)
# Configure pipeline options
pipeline_options = PdfPipelineOptions()
pipeline_options.do_ocr = False # Disable OCR for performance
pipeline_options.do_table_structure = True
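            # With OCR disabled, scanned or image-only PDFs will yield little or
            # no text; set do_ocr = True if those need to be supported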
# Create converter with optimized options
converter = DocumentConverter(
allowed_formats=[
InputFormat.PDF,
InputFormat.DOCX,
InputFormat.HTML,
InputFormat.PPTX,
],
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options,
),
InputFormat.DOCX: WordFormatOption(
pipeline_cls=SimplePipeline
)
}
)
# Convert document
print("Starting conversion...")
conv_result = converter.convert(temp_input)
if not conv_result or not conv_result.document:
raise ValueError(f"Failed to convert document: {doc_path}")
# Export to markdown
print("Exporting to markdown...")
md = conv_result.document.export_to_markdown()
# Create output path
output_dir = os.path.dirname(input_path)
base_name = os.path.splitext(os.path.basename(input_path))[0]
md_path = os.path.join(output_dir, f"{base_name}_converted.md")
# Write markdown file
with open(md_path, "w", encoding="utf-8") as fp:
fp.write(md)
return md_path
except Exception as e:
return f"Error converting document: {str(e)}"
# Improved text processing function
def clean_and_prepare_text(markdown_path):
"""Load, clean and prepare document text for better processing"""
try:
# Load the document
loader = UnstructuredMarkdownLoader(str(markdown_path))
documents = loader.load()
if not documents:
return None, "No content could be extracted from the document."
        # Split into paragraphs BEFORE normalizing whitespace, since collapsing
        # all whitespace first would destroy the blank-line separators we split on
        raw_text = "\n\n".join(doc.page_content for doc in documents)
        paragraphs = []
        for para in raw_text.split("\n\n"):
            # 1. Normalize whitespace within the paragraph
            para = " ".join(para.split())
            # 2. Fix common OCR and conversion artifacts
            para = para.replace(" .", ".").replace(" ,", ",")
            # 3. Ensure a space after sentence-ending punctuation that runs
            #    straight into the next word (letter lookahead avoids touching decimals)
            para = re.sub(r"([.!?])(?=[A-Za-z])", r"\1 ", para)
            if para.strip():
                paragraphs.append(para.strip())
# Create structured documents for better processing
processed_docs = []
for i, para in enumerate(paragraphs):
if len(para) > 10: # Skip very short paragraphs
processed_docs.append(Document(
page_content=para,
metadata={"source": markdown_path, "paragraph": i}
))
return processed_docs, None
except Exception as e:
return None, f"Error processing document text: {str(e)}"
# Improved text splitting configuration
def create_optimized_text_splitter():
"""Create an optimized text splitter for document processing"""
return RecursiveCharacterTextSplitter(
chunk_size=800, # Slightly smaller for more focused chunks
chunk_overlap=150, # Increased overlap to maintain context
length_function=len,
separators=["\n\n", "\n", ".", "!", "?", ";", ":", " ", ""] # More comprehensive separators
)
# Function to generate a summary using the IBM Granite model
def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
"""Generate a summary from document chunks using the IBM Granite model
Args:
chunks: List of document chunks to summarize
length_type: Either "sentences" or "paragraphs"
length_count: Number of sentences (1-10) or paragraphs (1-3)
"""
# Print debug information
print(f"Generating summary with length_type={length_type}, length_count={length_count}")
# Ensure length_count is an integer
try:
length_count = int(length_count)
except (ValueError, TypeError):
print(f"Failed to convert length_count to int: {length_count}, using default 3")
length_count = 3
# Apply limits based on type
if length_type == "sentences":
length_count = max(1, min(10, length_count)) # Limit to 1-10 sentences
else: # paragraphs
length_count = max(1, min(3, length_count)) # Limit to 1-3 paragraphs
# Clean and concatenate the text from chunks
# Remove any excessive whitespace and normalize
cleaned_chunks = []
for chunk in chunks:
text = chunk.page_content
# Remove excessive newlines and whitespace
text = ' '.join(text.split())
cleaned_chunks.append(text)
combined_text = " ".join(cleaned_chunks)
# More explicit and forceful prompt structure
if length_type == "sentences":
length_instruction = f"Create a concise summary that is EXACTLY {length_count} complete sentences. Not {length_count-1} sentences. Not {length_count+1} sentences. EXACTLY {length_count} sentences."
else: # paragraphs
length_instruction = f"Create a concise summary that is EXACTLY {length_count} paragraphs. Each paragraph should be 2-4 sentences long. Not {length_count-1} paragraphs. Not {length_count+1} paragraphs. EXACTLY {length_count} paragraphs."
# More detailed prompt with examples of what constitutes a sentence
prompt = f"""
You are an expert document summarizer. Your task is to create a high-quality summary of the following text.
{length_instruction}
Remember:
- Your summary must capture the main points of the document
- Your summary must be in your own words (not copied text)
- Your summary must be clearly written and well-structured
- Do not include any explanations, headings, bullet points, or additional formatting
- Respond ONLY with the summary text itself
Text to summarize:
{combined_text}
"""
# Calculate appropriate max_new_tokens but with stricter limits
if length_type == "sentences":
# Approximately 20 tokens per sentence
max_tokens = length_count * 40
else: # paragraphs
# Approximately 100 tokens per paragraph
max_tokens = length_count * 150
# Ensure minimum tokens and add buffer
max_tokens = max(100, min(1500, max_tokens))
print(f"Using max_new_tokens={max_tokens}")
# Generate with lower temperature for more consistent results
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.3, # Lower temperature for more deterministic output
top_p=0.9,
do_sample=True,
repetition_penalty=1.2 # Discourage repetition
)
    # Decode only the newly generated tokens (slicing the token sequence is
    # more robust than slicing the decoded string by prompt length)
    prompt_length = inputs["input_ids"].shape[1]
    summary = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
# Post-process the summary to ensure it meets the length constraints
if length_type == "sentences":
        # Split on sentence-ending punctuation followed by whitespace
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", summary) if s.strip()]
        if len(sentences) > length_count:
            # Keep only the requested number of sentences
            summary = " ".join(sentences[:length_count])
elif len(sentences) < length_count:
# If we have too few sentences, log this issue
print(f"Warning: Generated only {len(sentences)} sentences instead of {length_count}")
return summary.strip()
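# Example (hypothetical) direct call, e.g. for quick testing in a REPL:
#   docs = [Document(page_content="FAISS indexes dense vectors for similarity search. ...")]
#   print(generate_summary(docs, length_type="sentences", length_count=2))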
# Function to process document chunks efficiently
def process_document_chunks(texts):
    """Embed document chunks and build a FAISS vector store"""
try:
# Create embeddings with optimized settings
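        # nomic-embed-text-v1 ships custom modeling code on the Hugging Face Hub,
        # which is why trust_remote_code must be enabled below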
embeddings = HuggingFaceEmbeddings(
model_name="nomic-ai/nomic-embed-text-v1",
model_kwargs={'trust_remote_code': True}
)
# Create vector store more efficiently
vectorstore = FAISS.from_documents(
texts,
embeddings,
            # Cosine distance for retrieval; the DistanceStrategy enum is needed here,
            # as a bare lowercase string would not match the enum values and would
            # silently fall back to the default L2 index
            distance_strategy=DistanceStrategy.COSINE
)
return vectorstore
except Exception as e:
print(f"Error in document processing: {str(e)}")
# Fallback to basic processing if optimization fails
embeddings = HuggingFaceEmbeddings(
model_name="nomic-ai/nomic-embed-text-v1",
model_kwargs={'trust_remote_code': True}
)
return FAISS.from_documents(texts, embeddings)
# Main function to process document and generate summary
@spaces.GPU
def process_document(
file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
length_type: str = "sentences",
length_count: int = 3,
progress=gr.Progress()
):
"""Process a document file and generate a summary"""
try:
# Process input file
if not file_obj:
return "Please provide a file to summarize."
document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
# Validate document format
format_type = get_document_format(document_path)
if not format_type:
return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file."
# Convert document to markdown
progress(0.3, "Converting document to markdown...")
markdown_path = convert_document_to_markdown(document_path)
if markdown_path.startswith("Error"):
return markdown_path
# Clean and prepare the text
progress(0.4, "Processing document text...")
processed_docs, error = clean_and_prepare_text(markdown_path)
if error:
return error
# Split the documents with optimized splitter
text_splitter = create_optimized_text_splitter()
texts = text_splitter.split_documents(processed_docs)
if not texts:
return "No text could be extracted from the document."
# Create vector store with efficient processing
progress(0.6, "Processing document content...")
vectorstore = process_document_chunks(texts)
# Process chunks in smaller batches for memory efficiency
progress(0.8, "Generating summary...")
all_chunks = []
batch_size = 4 # Smaller batch size for memory efficiency
# Get all document chunks
doc_ids = list(vectorstore.index_to_docstore_id.values())
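        # (the FAISS store doubles as a plain docstore here: every chunk is
        # pulled back out, since whole-document summarization needs them all)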
# Process in smaller batches
for i in range(0, len(doc_ids), batch_size):
batch_ids = doc_ids[i:i+batch_size]
batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
all_chunks.extend(batch_chunks)
# Force garbage collection to free memory
gc.collect()
# Sleep briefly to allow memory cleanup
time.sleep(0.1)
# Case 1: Very small documents - use all chunks directly
if len(all_chunks) <= 8:
return generate_summary(
all_chunks,
length_type=length_type.lower(),
length_count=length_count
)
# Case 2: Medium-sized documents - process in one batch
elif len(all_chunks) <= 16:
return generate_summary(
                all_chunks[:8],  # the first 8 chunks usually carry the key content
length_type=length_type.lower(),
length_count=length_count
)
# Case 3: Large documents - process in multiple batches
else:
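            # This is a map-reduce style pass: summarize each batch of chunks,
            # then condense the batch summaries into the final answer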
# First pass: Generate summaries for each batch
summaries = []
for i in range(0, len(all_chunks), batch_size):
batch = all_chunks[i:i+batch_size]
summary = generate_summary(
batch,
length_type="paragraphs", # Use paragraphs for intermediate summaries
length_count=1 # One paragraph per batch
)
summaries.append(summary)
# Force garbage collection
gc.collect()
# Second pass: Generate final summary from batch summaries
final_summary = generate_summary(
[Document(page_content=s) for s in summaries],
length_type=length_type.lower(),
length_count=length_count
)
return final_summary
except Exception as e:
return f"Error processing document: {str(e)}"
# Create Gradio interface
def create_gradio_interface():
"""Create and launch the Gradio interface"""
with gr.Blocks(title="Granite Document Summarization") as app:
gr.Markdown("# Granite Document Summarization")
gr.Markdown("Upload a document to generate a summary.")
with gr.Row():
with gr.Column(scale=1):
file_input = gr.File(
label="Upload Document (PDF, DOCX, PPTX, HTML)",
file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"]
)
with gr.Row():
length_type = gr.Radio(
choices=["Sentences", "Paragraphs"],
value="Sentences",
label="Summary Length Type"
)
with gr.Row():
# Use slider for sentence count (1-10)
sentence_count = gr.Slider(
minimum=1,
maximum=10,
value=3,
step=1,
label="Number of Sentences",
visible=True
)
# Use radio for paragraph count (1-3)
paragraph_count = gr.Radio(
choices=["1", "2", "3"],
value="1",
label="Number of Paragraphs",
visible=False
)
submit_btn = gr.Button("Summarize", variant="primary")
with gr.Column(scale=2):
output = gr.TextArea(
label="Summary",
lines=15,
max_lines=30
)
# Add interactivity to show/hide appropriate count selector
def update_count_visibility(length_type):
is_sentences = length_type == "Sentences"
return [
gr.update(visible=is_sentences), # For sentence_count
gr.update(visible=not is_sentences) # For paragraph_count
]
length_type.change(
fn=update_count_visibility,
inputs=[length_type],
outputs=[sentence_count, paragraph_count]
)
# Function to handle form submission properly
def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
# Convert capitalized length_type to lowercase for processing
length_type_lower = length_type.lower()
print(f"Processing with length_type={length_type}, sentence_count={sentence_count}, paragraph_count={paragraph_count}")
# Determine count based on the selected length type
if length_type_lower == "sentences":
# For sentences, use the slider value directly
try:
count = int(sentence_count)
count = max(1, min(10, count)) # Ensure within range 1-10
print(f"Using sentence count: {count}")
except (ValueError, TypeError):
print(f"Invalid sentence count: {sentence_count}, using default 3")
count = 3
            else:
                # For paragraphs the radio button delivers a string, so a single
                # int() conversion covers every case the old type checks handled
                try:
                    count = int(paragraph_count)
                    count = max(1, min(3, count))  # Ensure within range 1-3
                    print(f"Using paragraph count: {count}")
                except (ValueError, TypeError):
                    print(f"Invalid paragraph count: {paragraph_count}, using default 1")
                    count = 1
return process_document(file, length_type_lower, count)
submit_btn.click(
fn=process_document_wrapper,
inputs=[file_input, length_type, sentence_count, paragraph_count],
outputs=output
)
gr.Markdown("""
## How to use:
1. Upload a document (PDF, DOCX, PPTX, HTML)
2. Choose your summary length preference:
- Number of Sentences (1-10)
- Number of Paragraphs (1-3)
3. Click "Summarize" to process the document
*This application uses the IBM Granite 3.3-8b model to generate summaries.*
""")
return app
# Launch the application
if __name__ == "__main__":
app = create_gradio_interface()
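    # Optional: app.launch(share=True) creates a temporary public link
    # (a standard Gradio option); the default is fine on Hugging Face Spaces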
app.launch()