import os import tempfile import shutil import torch import gradio as gr from pathlib import Path from typing import Optional, List, Union import gc import time # Docling imports from docling.datamodel.base_models import InputFormat from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline # LangChain imports from langchain_community.document_loaders import UnstructuredMarkdownLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain.schema import Document # Transformers imports for IBM Granite model import spaces from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # Initialize IBM Granite model and tokenizer print("Loading Granite model and tokenizer...") model_name = "ibm-granite/granite-3.3-8b-instruct" # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) # Create quantization config quantization_config = BitsAndBytesConfig( load_in_4bit=True, # Use 4-bit quantization for better memory efficiency bnb_4bit_compute_dtype=torch.bfloat16 # Use bfloat16 for computation with 4-bit quantization ) # Load model with optimization for GPU model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", quantization_config=quantization_config ) print("Model loaded successfully!") # Helper function to detect document format def get_document_format(file_path) -> Optional[InputFormat]: """Determine the document format based on file extension""" try: file_path = str(file_path) extension = os.path.splitext(file_path)[1].lower() format_map = { '.pdf': InputFormat.PDF, '.docx': InputFormat.DOCX, '.doc': InputFormat.DOCX, '.pptx': InputFormat.PPTX, '.html': InputFormat.HTML, '.htm': InputFormat.HTML } return format_map.get(extension) except Exception as e: print(f"Error in get_document_format: {str(e)}") return None # Function to convert documents to markdown def convert_document_to_markdown(doc_path) -> str: """Convert document to markdown using simplified pipeline""" try: # Convert to absolute path string input_path = os.path.abspath(str(doc_path)) print(f"Converting document: {doc_path}") # Create temporary directory for processing with tempfile.TemporaryDirectory() as temp_dir: # Copy input file to temp directory temp_input = os.path.join(temp_dir, os.path.basename(input_path)) shutil.copy2(input_path, temp_input) # Configure pipeline options pipeline_options = PdfPipelineOptions() pipeline_options.do_ocr = False # Disable OCR for performance pipeline_options.do_table_structure = True # Create converter with optimized options converter = DocumentConverter( allowed_formats=[ InputFormat.PDF, InputFormat.DOCX, InputFormat.HTML, InputFormat.PPTX, ], format_options={ InputFormat.PDF: PdfFormatOption( pipeline_options=pipeline_options, ), InputFormat.DOCX: WordFormatOption( pipeline_cls=SimplePipeline ) } ) # Convert document print("Starting conversion...") conv_result = converter.convert(temp_input) if not conv_result or not conv_result.document: raise ValueError(f"Failed to convert document: {doc_path}") # Export to markdown print("Exporting to markdown...") md = conv_result.document.export_to_markdown() # Create output path output_dir = os.path.dirname(input_path) base_name = os.path.splitext(os.path.basename(input_path))[0] md_path = os.path.join(output_dir, f"{base_name}_converted.md") # Write markdown file with open(md_path, "w", encoding="utf-8") as fp: fp.write(md) return md_path except Exception as e: return f"Error converting document: {str(e)}" # Improved text processing function def clean_and_prepare_text(markdown_path): """Load, clean and prepare document text for better processing""" try: # Load the document loader = UnstructuredMarkdownLoader(str(markdown_path)) documents = loader.load() if not documents: return None, "No content could be extracted from the document." # Combine all document content for pre-processing raw_text = " ".join([doc.page_content for doc in documents]) # Clean up the text # 1. Normalize whitespace text = " ".join(raw_text.split()) # 2. Fix common OCR and conversion artifacts text = text.replace(" .", ".").replace(" ,", ",") # 3. Ensure proper spacing after punctuation for punct in ['.', '!', '?']: text = text.replace(f"{punct}", f"{punct} ") # Split into improved documents # Use a sensible paragraph size paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] # Create structured documents for better processing processed_docs = [] for i, para in enumerate(paragraphs): if len(para) > 10: # Skip very short paragraphs processed_docs.append(Document( page_content=para, metadata={"source": markdown_path, "paragraph": i} )) return processed_docs, None except Exception as e: return None, f"Error processing document text: {str(e)}" # Improved text splitting configuration def create_optimized_text_splitter(): """Create an optimized text splitter for document processing""" return RecursiveCharacterTextSplitter( chunk_size=800, # Slightly smaller for more focused chunks chunk_overlap=150, # Increased overlap to maintain context length_function=len, separators=["\n\n", "\n", ".", "!", "?", ";", ":", " ", ""] # More comprehensive separators ) # Function to generate a summary using the IBM Granite model def generate_summary(chunks: List[Document], length_type="sentences", length_count=3): """Generate a summary from document chunks using the IBM Granite model Args: chunks: List of document chunks to summarize length_type: Either "sentences" or "paragraphs" length_count: Number of sentences (1-10) or paragraphs (1-3) """ # Print debug information print(f"Generating summary with length_type={length_type}, length_count={length_count}") # Ensure length_count is an integer try: length_count = int(length_count) except (ValueError, TypeError): print(f"Failed to convert length_count to int: {length_count}, using default 3") length_count = 3 # Apply limits based on type if length_type == "sentences": length_count = max(1, min(10, length_count)) # Limit to 1-10 sentences else: # paragraphs length_count = max(1, min(3, length_count)) # Limit to 1-3 paragraphs # Clean and concatenate the text from chunks # Remove any excessive whitespace and normalize cleaned_chunks = [] for chunk in chunks: text = chunk.page_content # Remove excessive newlines and whitespace text = ' '.join(text.split()) cleaned_chunks.append(text) combined_text = " ".join(cleaned_chunks) # More explicit and forceful prompt structure if length_type == "sentences": length_instruction = f"Create a concise summary that is EXACTLY {length_count} complete sentences. Not {length_count-1} sentences. Not {length_count+1} sentences. EXACTLY {length_count} sentences." else: # paragraphs length_instruction = f"Create a concise summary that is EXACTLY {length_count} paragraphs. Each paragraph should be 2-4 sentences long. Not {length_count-1} paragraphs. Not {length_count+1} paragraphs. EXACTLY {length_count} paragraphs." # More detailed prompt with examples of what constitutes a sentence prompt = f""" You are an expert document summarizer. Your task is to create a high-quality summary of the following text. {length_instruction} Remember: - Your summary must capture the main points of the document - Your summary must be in your own words (not copied text) - Your summary must be clearly written and well-structured - Do not include any explanations, headings, bullet points, or additional formatting - Respond ONLY with the summary text itself {combined_text} """ # Calculate appropriate max_new_tokens but with stricter limits if length_type == "sentences": # Approximately 20 tokens per sentence max_tokens = length_count * 40 else: # paragraphs # Approximately 100 tokens per paragraph max_tokens = length_count * 150 # Ensure minimum tokens and add buffer max_tokens = max(100, min(1500, max_tokens)) print(f"Using max_new_tokens={max_tokens}") # Generate with lower temperature for more consistent results inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): output = model.generate( **inputs, max_new_tokens=max_tokens, temperature=0.3, # Lower temperature for more deterministic output top_p=0.9, do_sample=True, repetition_penalty=1.2 # Discourage repetition ) # Decode and return the generated summary summary = tokenizer.decode(output[0], skip_special_tokens=True) # Extract just the generated response (after the prompt) summary = summary[len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)):] summary = summary.strip() # Post-process the summary to ensure it meets the length constraints if length_type == "sentences": # Simple sentence counting based on periods sentences = [s.strip() for s in summary.split('.') if s.strip()] if len(sentences) > length_count: # Take only the requested number of sentences summary = '. '.join(sentences[:length_count]) + '.' elif len(sentences) < length_count: # If we have too few sentences, log this issue print(f"Warning: Generated only {len(sentences)} sentences instead of {length_count}") return summary.strip() # Function to process document chunks efficiently def process_document_chunks(texts, batch_size=8): """Process document chunks in efficient batches""" try: # Create embeddings with optimized settings embeddings = HuggingFaceEmbeddings( model_name="nomic-ai/nomic-embed-text-v1", model_kwargs={'trust_remote_code': True} ) # Create vector store more efficiently vectorstore = FAISS.from_documents( texts, embeddings, # Add distance function for better retrieval distance_strategy="cosine" ) return vectorstore except Exception as e: print(f"Error in document processing: {str(e)}") # Fallback to basic processing if optimization fails embeddings = HuggingFaceEmbeddings( model_name="nomic-ai/nomic-embed-text-v1", model_kwargs={'trust_remote_code': True} ) return FAISS.from_documents(texts, embeddings) # Main function to process document and generate summary @spaces.GPU def process_document( file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None, length_type: str = "sentences", length_count: int = 3, progress=gr.Progress() ): """Process a document file and generate a summary""" try: # Process input file if not file_obj: return "Please provide a file to summarize." document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj) # Validate document format format_type = get_document_format(document_path) if not format_type: return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file." # Convert document to markdown progress(0.3, "Converting document to markdown...") markdown_path = convert_document_to_markdown(document_path) if markdown_path.startswith("Error"): return markdown_path # Clean and prepare the text progress(0.4, "Processing document text...") processed_docs, error = clean_and_prepare_text(markdown_path) if error: return error # Split the documents with optimized splitter text_splitter = create_optimized_text_splitter() texts = text_splitter.split_documents(processed_docs) if not texts: return "No text could be extracted from the document." # Create vector store with efficient processing progress(0.6, "Processing document content...") vectorstore = process_document_chunks(texts) # Create retriever with optimized settings retriever = vectorstore.as_retriever( search_type="similarity", search_kwargs={"k": 4} # Number of chunks to retrieve ) # Process chunks in smaller batches for memory efficiency progress(0.8, "Generating summary...") all_chunks = [] batch_size = 4 # Smaller batch size for memory efficiency # Get all document chunks doc_ids = list(vectorstore.index_to_docstore_id.values()) # Process in smaller batches for i in range(0, len(doc_ids), batch_size): batch_ids = doc_ids[i:i+batch_size] batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids] all_chunks.extend(batch_chunks) # Force garbage collection to free memory gc.collect() # Sleep briefly to allow memory cleanup time.sleep(0.1) # Case 1: Very small documents - use all chunks directly if len(all_chunks) <= 8: return generate_summary( all_chunks, length_type=length_type.lower(), length_count=length_count ) # Case 2: Medium-sized documents - process in one batch elif len(all_chunks) <= 16: return generate_summary( all_chunks[:8], # Use first 8 chunks (usually contains most important info) length_type=length_type.lower(), length_count=length_count ) # Case 3: Large documents - process in multiple batches else: # First pass: Generate summaries for each batch summaries = [] for i in range(0, len(all_chunks), batch_size): batch = all_chunks[i:i+batch_size] summary = generate_summary( batch, length_type="paragraphs", # Use paragraphs for intermediate summaries length_count=1 # One paragraph per batch ) summaries.append(summary) # Force garbage collection gc.collect() # Second pass: Generate final summary from batch summaries final_summary = generate_summary( [Document(page_content=s) for s in summaries], length_type=length_type.lower(), length_count=length_count ) return final_summary except Exception as e: return f"Error processing document: {str(e)}" # Create Gradio interface def create_gradio_interface(): """Create and launch the Gradio interface""" with gr.Blocks(title="Granite Document Summarization") as app: gr.Markdown("# Granite Document Summarization") gr.Markdown("Upload a document to generate a summary.") with gr.Row(): with gr.Column(scale=1): file_input = gr.File( label="Upload Document (PDF, DOCX, PPTX, HTML)", file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"] ) with gr.Row(): length_type = gr.Radio( choices=["Sentences", "Paragraphs"], value="Sentences", label="Summary Length Type" ) with gr.Row(): # Use slider for sentence count (1-10) sentence_count = gr.Slider( minimum=1, maximum=10, value=3, step=1, label="Number of Sentences", visible=True ) # Use radio for paragraph count (1-3) paragraph_count = gr.Radio( choices=["1", "2", "3"], value="1", label="Number of Paragraphs", visible=False ) submit_btn = gr.Button("Summarize", variant="primary") with gr.Column(scale=2): output = gr.TextArea( label="Summary", lines=15, max_lines=30 ) # Add interactivity to show/hide appropriate count selector def update_count_visibility(length_type): is_sentences = length_type == "Sentences" return [ gr.update(visible=is_sentences), # For sentence_count gr.update(visible=not is_sentences) # For paragraph_count ] length_type.change( fn=update_count_visibility, inputs=[length_type], outputs=[sentence_count, paragraph_count] ) # Function to handle form submission properly def process_document_wrapper(file, length_type, sentence_count, paragraph_count): # Convert capitalized length_type to lowercase for processing length_type_lower = length_type.lower() print(f"Processing with length_type={length_type}, sentence_count={sentence_count}, paragraph_count={paragraph_count}") # Determine count based on the selected length type if length_type_lower == "sentences": # For sentences, use the slider value directly try: count = int(sentence_count) count = max(1, min(10, count)) # Ensure within range 1-10 print(f"Using sentence count: {count}") except (ValueError, TypeError): print(f"Invalid sentence count: {sentence_count}, using default 3") count = 3 else: # For paragraphs, convert from string to int if needed try: # Check if paragraph_count is a string (from radio button) if isinstance(paragraph_count, str): count = int(paragraph_count) # Check if it's a boolean (from visibility toggle) elif isinstance(paragraph_count, bool): count = 1 # Default if boolean else: count = int(paragraph_count) count = max(1, min(3, count)) # Ensure within range 1-3 print(f"Using paragraph count: {count}") except (ValueError, TypeError): print(f"Invalid paragraph count: {paragraph_count}, using default 1") count = 1 return process_document(file, length_type_lower, count) submit_btn.click( fn=process_document_wrapper, inputs=[file_input, length_type, sentence_count, paragraph_count], outputs=output ) gr.Markdown(""" ## How to use: 1. Upload a document (PDF, DOCX, PPTX, HTML) 2. Choose your summary length preference: - Number of Sentences (1-10) - Number of Paragraphs (1-3) 3. Click "Summarize" to process the document *This application uses the IBM Granite 3.3-8b model to generate summaries.* """) return app # Launch the application if __name__ == "__main__": app = create_gradio_interface() app.launch()