import streamlit as st
import shelve
import docx2txt
import PyPDF2
import time  # Used to simulate the typing effect
import nltk
import re
import os
from dotenv import load_dotenv
import torch
from sentence_transformers import SentenceTransformer, util
import hashlib
from nltk import sent_tokenize
from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers import pipeline
import asyncio
import dateutil.parser
from datetime import datetime
import sys
from openai import OpenAI
import numpy as np

nltk.download('punkt')
nltk.download('punkt_tab')

# Fix for RuntimeError: no running event loop on Windows
if sys.platform.startswith("win"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

st.set_page_config(page_title="Legal Document Summarizer", layout="wide")

if "processed" not in st.session_state:
    st.session_state.processed = False
if "last_uploaded_hash" not in st.session_state:
    st.session_state.last_uploaded_hash = None
if "chat_prompt_processed" not in st.session_state:
    st.session_state.chat_prompt_processed = False
if "embedding_text" not in st.session_state:
    st.session_state.embedding_text = None
if "document_context" not in st.session_state:
    st.session_state.document_context = None
if "last_prompt_hash" not in st.session_state:
    st.session_state.last_prompt_hash = None

st.title("📄 Legal Document Summarizer (Document Augmentation RAG)")

USER_AVATAR = "👤"
BOT_AVATAR = "🤖"


# Load chat history
def load_chat_history():
    with shelve.open("chat_history") as db:
        return db.get("messages", [])


# Save chat history
def save_chat_history(messages):
    with shelve.open("chat_history") as db:
        db["messages"] = messages


# Limit a text preview to a given number of words
def limit_text(text, word_limit=500):
    words = text.split()
    return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")


# CLEAN AND NORMALIZE TEXT
def clean_text(text):
    # Remove newlines and extra spaces
    text = text.replace('\r\n', ' ').replace('\n', ' ')
    text = re.sub(r'\s+', ' ', text)

    # Remove page number markers like "Page 1 of 10"
    text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)

    # Remove long dashed or underscored separator lines
    text = re.sub(r'[_]{5,}', '', text)  # Lines of underscores: _____
    text = re.sub(r'[-]{5,}', '', text)  # Lines of hyphens: -----

    # Remove long dotted separators
    text = re.sub(r'[.]{4,}', '', text)  # Dots like "......" or ".............."
    # Trim leading/trailing whitespace
    text = text.strip()

    return text


#######################################################################################################################

# LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS

# Load tokens from the .env file
load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=os.getenv("OPENAI_API_KEY")
)

# print("API Key:", os.getenv("OPENAI_API_KEY"))  # Temporary, for debugging only


# Load the classifier once at the top (cached for performance)
@st.cache_resource
def load_local_zero_shot_classifier():
    return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")


local_classifier = load_local_zero_shot_classifier()

SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]


def classify_chunk(text):
    result = local_classifier(text, candidate_labels=SECTION_LABELS)
    return result["labels"][0]


# NLP-based sectioning using zero-shot classification
def section_by_zero_shot(text):
    # Keys must match SECTION_LABELS so classified chunks land in the right bucket
    sections = {"Facts": "", "Arguments": "", "Judgement": "", "Others": ""}
    sentences = sent_tokenize(text)
    chunk = ""

    for i, sent in enumerate(sentences):
        chunk += sent + " "
        if (i + 1) % 3 == 0 or i == len(sentences) - 1:
            label = classify_chunk(chunk.strip())
            print(f"🔎 Chunk: {chunk[:60]}...\n🔖 Predicted Label: {label}")

            # Normalize the label (title case, with a fallback bucket)
            label = label.capitalize()
            if label not in sections:
                label = "Others"

            sections[label] += chunk + "\n"
            chunk = ""

    return sections


#######################################################################################################################

# EXTRACTING TEXT FROM UPLOADED FILES

# Extract text from an uploaded PDF, DOCX, or TXT file
def extract_text(file):
    if file.name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
    elif file.name.endswith(".docx"):
        full_text = docx2txt.process(file)
    elif file.name.endswith(".txt"):
        full_text = file.read().decode("utf-8")
    else:
        return "Unsupported file type."

    return full_text  # Full text is needed for summarization
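
# Illustrative sketch (not executed): the shape of what section_by_zero_shot returns.
# The snippet text below is a made-up placeholder, not real classifier output.
#
#   sample_sections = section_by_zero_shot("The petitioner filed an appeal. ...")
#   # -> {"Facts": "...", "Arguments": "...", "Judgement": "...", "Others": ""}
#   # Each value holds the three-sentence chunks the zero-shot classifier assigned to that label.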
#######################################################################################################################

# EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION

@st.cache_resource
def load_legalbert():
    return SentenceTransformer("nlpaueb/legal-bert-base-uncased")


legalbert_model = load_legalbert()


@st.cache_resource
def load_led():
    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
    return tokenizer, model


tokenizer_led, model_led = load_led()


def legalbert_extractive_summary(text, top_ratio=0.2):
    sentences = sent_tokenize(text)
    top_k = max(3, int(len(sentences) * top_ratio))
    if len(sentences) <= top_k:
        return text
    sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
    doc_embedding = torch.mean(sentence_embeddings, dim=0)
    cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=top_k)
    selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
    return " ".join(selected_sentences)


# LED abstractive summarization
def led_abstractive_summary(text, max_length=512, min_length=100):
    inputs = tokenizer_led(
        text, return_tensors="pt", padding="max_length", truncation=True, max_length=4096
    )
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    outputs = model_led.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=4,              # Use beam search
        repetition_penalty=2.0,   # Penalize repetition
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=4    # Prevent repeated phrases
    )
    return tokenizer_led.decode(outputs[0], skip_special_tokens=True)


def led_abstractive_summary_chunked(text, max_tokens=3000):
    sentences = sent_tokenize(text)
    current_chunk, chunks, summaries = "", [], []

    for sent in sentences:
        if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
            chunks.append(current_chunk)
            current_chunk = sent
        else:
            current_chunk += " " + sent
    if current_chunk:
        chunks.append(current_chunk)

    for chunk in chunks:
        inputs = tokenizer_led(chunk, return_tensors="pt", padding="max_length", truncation=True, max_length=4096)
        global_attention_mask = torch.zeros_like(inputs["input_ids"])
        global_attention_mask[:, 0] = 1
        output = model_led.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            max_length=512,
            min_length=100,
            num_beams=4,
            repetition_penalty=2.0,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=4,
        )
        summaries.append(tokenizer_led.decode(output[0], skip_special_tokens=True))

    return " ".join(summaries)


def hybrid_summary_hierarchical(text, top_ratio=0.8):
    cleaned_text = clean_text(text)
    sections = section_by_zero_shot(cleaned_text)

    structured_summary = {}  # hierarchical summary: one entry per detected section

    for name, content in sections.items():
        if content.strip():
            # Extractive summary
            extractive = legalbert_extractive_summary(content, top_ratio)
            # Abstractive summary of the extractive output
            abstractive = led_abstractive_summary_chunked(extractive)
            # Store in a dictionary (hierarchical structure)
            structured_summary[name] = {
                "extractive": extractive,
                "abstractive": abstractive
            }

    return structured_summary


def chunk_text_custom(text, n=1000, overlap=200):
    chunks = []
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])
    return chunks
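
# Illustrative sketch (not executed): how the summarization helpers above compose.
# `sample_text` is a placeholder for the raw text of an uploaded judgment.
#
#   sample_summary = hybrid_summary_hierarchical(sample_text)
#   # -> {"Facts": {"extractive": "...", "abstractive": "..."},
#   #     "Judgement": {"extractive": "...", "abstractive": "..."}, ...}
#
#   sample_chunks = chunk_text_custom(sample_text, n=1000, overlap=200)
#   # fixed-size character chunks with a 200-character overlap between neighbours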
def get_embedding(text, model="BAAI/bge-en-icl"):
    """Create an embedding for the given text chunk using the BGE-ICL model."""
    resp = client.embeddings.create(model=model, input=text)
    return np.array(resp.data[0].embedding)


def semantic_search(query, text_chunks, chunk_embeddings, k=5):
    """
    Compute cosine similarity between the query embedding and each chunk embedding,
    then pick the top-k chunks.
    """
    q_emb = get_embedding(query)

    # simple cosine similarity
    def cosine(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    scores = [cosine(q_emb, emb) for emb in chunk_embeddings]
    top_idxs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [text_chunks[i] for i in top_idxs]


def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
    return client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[{"role": "system", "content": system_prompt},
                  {"role": "user", "content": user_message}]
    ).choices[0].message.content


# 1) QUESTION GENERATION
def generate_questions(text_chunk, num_questions=5, model="meta-llama/Llama-3.2-3B-Instruct"):
    system_prompt = (
        "You are an expert at generating relevant questions from text. "
        "Create concise questions that can be answered using only the provided text."
    )
    user_prompt = f"""
Based on the following text, generate {num_questions} different questions that can be answered using only this text:

{text_chunk}

Format your response as a numbered list of questions only.
"""
    resp = client.chat.completions.create(
        model=model,
        temperature=0.7,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )
    raw = resp.choices[0].message.content.strip()
    questions = []
    for line in raw.split("\n"):
        q = re.sub(r"^\d+\.\s*", "", line).strip()
        if q.endswith("?"):
            questions.append(q)
    return questions


# 2) EMBEDDINGS
def create_embeddings(text, model="BAAI/bge-en-icl"):
    resp = client.embeddings.create(model=model, input=text)
    return resp.data[0].embedding


def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# 3) VECTOR STORE
class SimpleVectorStore:
    def __init__(self):
        self.items = []  # each item is a dict {text, embedding, metadata}

    def add_item(self, text, embedding, metadata):
        self.items.append(dict(text=text, embedding=embedding, metadata=metadata))

    def search(self, query, k=5):
        q_emb = create_embeddings(query)
        scores = [(i, cosine_similarity(q_emb, item["embedding"])) for i, item in enumerate(self.items)]
        scores.sort(key=lambda x: x[1], reverse=True)
        return [self.items[i] for i, _ in scores[:k]]


# 4) DOCUMENT PROCESSOR
def process_document(raw_text, chunk_size=1000, chunk_overlap=200, questions_per_chunk=5):
    # chunk the text
    chunks = []
    for i in range(0, len(raw_text), chunk_size - chunk_overlap):
        chunks.append(raw_text[i:i + chunk_size])

    store = SimpleVectorStore()
    for idx, chunk in enumerate(chunks):
        # chunk embedding
        emb = create_embeddings(chunk)
        store.add_item(chunk, emb, {"type": "chunk", "index": idx})

        # generate questions for the chunk + their embeddings
        qs = generate_questions(chunk, num_questions=questions_per_chunk)
        for q in qs:
            q_emb = create_embeddings(q)
            store.add_item(q, q_emb, {
                "type": "question",
                "chunk_index": idx,
                "original_chunk": chunk
            })
    return chunks, store


# 5) CONTEXT BUILDER
def prepare_context(results):
    seen = set()
    ctx = []

    # first, chunks retrieved directly
    for r in results:
        m = r["metadata"]
        if m["type"] == "chunk" and m["index"] not in seen:
            seen.add(m["index"])
            ctx.append(f"Chunk {m['index']}:\n{r['text']}")

    # then, chunks referenced by retrieved questions
    for r in results:
        m = r["metadata"]
        if m["type"] == "question":
            ci = m["chunk_index"]
            if ci not in seen:
                seen.add(ci)
                ctx.append(f"Chunk {ci} (via Q “{r['text']}”):\n{m['original_chunk']}")

    return "\n\n".join(ctx)
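
# Illustrative sketch (not executed): the document-augmentation ingestion and retrieval
# flow built from the pieces above. `sample_text` and `sample_query` are placeholders.
#
#   sample_chunks, sample_store = process_document(sample_text, chunk_size=1000,
#                                                  chunk_overlap=200, questions_per_chunk=5)
#   sample_results = sample_store.search(sample_query, k=5)   # mixes chunks and generated questions
#   sample_context = prepare_context(sample_results)          # de-duplicated chunk context
#   # the answer is then produced by generate_response_from_context(sample_query, sample_context)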
# 6) ANSWER GENERATOR (replaces the generic generate_response above)
def generate_response_from_context(query, context, model="meta-llama/Llama-3.2-3B-Instruct"):
    sp = (
        "You are an AI assistant that strictly answers based on the given context. "
        "If the answer cannot be derived directly from the provided context, "
        "respond with: 'I do not have enough information to answer that.'"
    )
    up = f"""
Context:
{context}

Question: {query}

Please answer the question based only on the context above.
"""
    resp = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[{"role": "system", "content": sp},
                  {"role": "user", "content": up}]
    )
    return resp.choices[0].message.content


#######################################################################################################################

# STREAMLIT APP INTERFACE CODE

# Initialize or load chat history
if "messages" not in st.session_state:
    st.session_state.messages = load_chat_history()

# Initialize last_uploaded if not set
if "last_uploaded" not in st.session_state:
    st.session_state.last_uploaded = None

# Sidebar with a button to delete chat history
with st.sidebar:
    st.subheader("⚙️ Options")
    if st.button("Delete Chat History"):
        st.session_state.messages = []
        st.session_state.last_uploaded = None
        st.session_state.processed = False
        st.session_state.chat_prompt_processed = False
        save_chat_history([])


# Display chat messages with a typing effect
def display_with_typing_effect(text, speed=0.005):
    placeholder = st.empty()
    displayed_text = ""
    for char in text:
        displayed_text += char
        placeholder.markdown(displayed_text)
        time.sleep(speed)
    return displayed_text


# Show existing chat messages
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

# Standard chat input field
prompt = st.chat_input("Type a message...")

# Place the uploader before the chat so it is always visible
with st.container():
    st.subheader("📎 Upload a Legal Document")
    uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
    reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")


# Hashing logic used to detect new uploads
def get_file_hash(file):
    file.seek(0)
    content = file.read()
    file.seek(0)
    return hashlib.md5(content).hexdigest()


# Combine the extractive and abstractive summaries into a single string for embedding
def prepare_text_for_embedding(summary_dict):
    combined_chunks = []
    for section, content in summary_dict.items():
        ext = content.get("extractive", "").strip()
        abst = content.get("abstractive", "").strip()
        if ext:
            combined_chunks.append(f"{section} - Extractive Summary:\n{ext}")
        if abst:
            combined_chunks.append(f"{section} - Abstractive Summary:\n{abst}")
    return "\n\n".join(combined_chunks)


##############################################################################################################

user_role = st.sidebar.selectbox(
    "🎭 Select Your Role for Custom Summary",
    ["General", "Judge", "Lawyer", "Student"]
)


def role_based_filter(section, summary, role):
    if role == "General":
        return summary

    filtered_summary = {
        "extractive": "",
        "abstractive": ""
    }

    if role == "Judge" and section in ["Judgement", "Facts"]:
        filtered_summary = summary
    elif role == "Lawyer" and section in ["Arguments", "Facts"]:
        filtered_summary = summary
    elif role == "Student" and section in ["Facts"]:
        filtered_summary = summary

    return filtered_summary


#########################################################################################################################

if uploaded_file:
    file_hash = get_file_hash(uploaded_file)
    if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
        st.session_state.processed = False

    if not st.session_state.processed:
        start_time = time.time()

        # 1) extract & summarize as before
        raw_text = extract_text(uploaded_file)
        summary_dict = hybrid_summary_hierarchical(raw_text)
        embedding_text = prepare_text_for_embedding(summary_dict)

        # --- NEW: document-augmentation ingestion ---
        chunks, store = process_document(raw_text,
                                         chunk_size=1000,
                                         chunk_overlap=200,
                                         questions_per_chunk=5)
        st.session_state.vector_store = store
        # ---------------------------------------------

        # 2) generate the role-specific prompt as before
        st.session_state.document_context = embedding_text

        if user_role == "General":
            role_specific_prompt = (
                "Summarize the legal document focusing on the most relevant aspects "
                "such as facts, arguments, and judgments. Include key legal reasoning "
                "and a timeline of events where necessary."
            )
        else:
            role_specific_prompt = (
                f"As a {user_role}, summarize the legal document focusing on "
                "the most relevant aspects such as facts, arguments, and judgments "
                "tailored for your role. Include key legal reasoning and timeline of events."
            )

        # --- replace rag_query_response with document-augmentation RAG ---
        results = store.search(role_specific_prompt, k=5)
        context = prepare_context(results)
        rag_summary = generate_response_from_context(role_specific_prompt, context)

        # st.session_state.messages.append({
        #     "role": "user",
        #     "content": f"📤 Uploaded **{uploaded_file.name}**"
        # })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rag_summary
        })

        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(rag_summary)

        processing_time = round((time.time() - start_time) / 60, 2)
        st.info(f"⏱️ Response generated in **{processing_time} minutes**.")

        st.session_state.generated_summary = rag_summary
        st.session_state.last_uploaded_hash = file_hash
        st.session_state.processed = True
        st.session_state.last_prompt_hash = None
        save_chat_history(st.session_state.messages)


if prompt:
    words = prompt.split()
    word_count = len(words)
    prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()

    # 1) LONG prompts – echo & ingest like a "pasted-in" document
    if word_count > 30 and prompt_hash != st.session_state.last_prompt_hash:
        st.session_state.last_prompt_hash = prompt_hash
        raw_text = prompt

        st.session_state.messages.append({
            "role": "user",
            "content": f"📥 **Pasted Document Text:**\n\n{limit_text(raw_text, 500)}"
        })
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(limit_text(raw_text, 500))

        start_time = time.time()

        # summarization + embedding text as before
        summary_dict = hybrid_summary_hierarchical(raw_text)
        emb_text = prepare_text_for_embedding(summary_dict)
        st.session_state.document_context = emb_text
        st.session_state.processed = True

        # --- NEW: ingest via document-augmentation ---
        chunks, store = process_document(raw_text)
        st.session_state.vector_store = store

        if user_role == "General":
            role_prompt = (
                "Summarize the document focusing on facts, arguments, judgments, "
                "and include a timeline of events."
            )
        else:
            role_prompt = (
                f"As a {user_role}, summarize the document focusing on facts, "
                "arguments, judgments, plus timeline of events."
            )

        # --- document-augmentation RAG here too ---
        results = store.search(role_prompt, k=5)
        context = prepare_context(results)
        initial_summary = generate_response_from_context(role_prompt, context)

        st.session_state.messages.append({
            "role": "assistant",
            "content": initial_summary
        })
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(initial_summary)

        st.info(f"⏱️ Summary generated in {round((time.time() - start_time) / 60, 2)} minutes")
        save_chat_history(st.session_state.messages)

    # 2) SHORT prompts – normal RAG against the last ingested context
    elif word_count <= 30 and st.session_state.processed:
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(prompt)

        # save to history
        st.session_state.messages.append({"role": "user", "content": prompt})

        store = st.session_state.vector_store

        # --- instead of rag_query_response, do document-augmentation RAG ---
        results = store.search(prompt, k=5)
        context = prepare_context(results)
        answer = generate_response_from_context(prompt, context)

        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(answer)

        save_chat_history(st.session_state.messages)

    # 3) not enough input
    else:
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            st.markdown("❗ Paste at least 30 words of your document to ingest it first.")


################################ Evaluation ############################
######################################################################################################################

# 📚 Imports
import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.metrics import f1_score


# 📌 Load Evaluators Once
@st.cache_resource
def load_evaluators():
    rouge = evaluate.load("rouge")
    bertscore = evaluate.load("bertscore")
    return rouge, bertscore


rouge, bertscore = load_evaluators()


# 📌 Define Evaluation Functions
def evaluate_summary(generated_summary, ground_truth_summary):
    """Evaluate ROUGE and BERTScore."""
    rouge_result = rouge.compute(predictions=[generated_summary], references=[ground_truth_summary])
    bert_result = bertscore.compute(predictions=[generated_summary], references=[ground_truth_summary], lang="en")
    return rouge_result, bert_result


def exact_match(prediction, ground_truth):
    return int(prediction.strip().lower() == ground_truth.strip().lower())


def compute_bleu(prediction, ground_truth):
    reference = [ground_truth.strip().split()]
    candidate = prediction.strip().split()
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, candidate, smoothing_function=smoothie)


def compute_f1(prediction, ground_truth):
    """Compute F1 score based on token overlap, as in QA evaluation."""
    pred_tokens = prediction.strip().lower().split()
    gt_tokens = ground_truth.strip().lower().split()

    common_tokens = set(pred_tokens) & set(gt_tokens)
    num_common = len(common_tokens)
    if num_common == 0:
        return 0.0

    precision = num_common / len(pred_tokens)
    recall = num_common / len(gt_tokens)
    f1 = 2 * (precision * recall) / (precision + recall)
    return f1


def evaluate_additional_metrics(prediction, ground_truth):
    em = exact_match(prediction, ground_truth)
    bleu = compute_bleu(prediction, ground_truth)
    f1 = compute_f1(prediction, ground_truth)
    return {
        "Exact Match": em,
        "BLEU Score": bleu,
        "F1 Score": f1
    }
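
# Illustrative sketch (not executed): the evaluation helpers above can also be called
# directly. The two strings here are made-up placeholders, not real model outputs.
#
#   rouge_result, bert_result = evaluate_summary("the appeal was dismissed",
#                                                "the court dismissed the appeal")
#   extra_metrics = evaluate_additional_metrics("the appeal was dismissed",
#                                               "the court dismissed the appeal")
#   # extra_metrics -> {"Exact Match": 0, "BLEU Score": <float>, "F1 Score": <float>}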
# 📥 Upload and Evaluate
ground_truth_summary_file = st.file_uploader("📄 Upload Ground Truth Summary (.txt)", type=["txt"])

if ground_truth_summary_file:
    ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()

    if "generated_summary" in st.session_state and st.session_state.generated_summary:
        prediction = st.session_state.generated_summary

        # Evaluate ROUGE and BERTScore
        rouge_result, bert_result = evaluate_summary(prediction, ground_truth_summary)

        # Display ROUGE and BERTScore
        st.subheader("📊 Evaluation Results")
        st.write("🔹 ROUGE Scores:")
        st.json(rouge_result)
        st.write("🔹 BERTScore:")
        st.json(bert_result)

        # Evaluate and display Exact Match, BLEU, and F1
        additional_metrics = evaluate_additional_metrics(prediction, ground_truth_summary)
        st.subheader("🔎 Additional Evaluation Metrics")
        st.json(additional_metrics)
    else:
        st.warning("⚠️ Please generate a summary first by uploading a document.")


######################################################################################################################

# Run this along with `streamlit run app.py` to evaluate the model's performance on a test set.
# Otherwise, leave the code below commented out.

# => EVALUATION HOOK: after the very first summary, fire off run_eval() once in a background thread

# import json
# import pandas as pd
# import threading
#
# def run_eval(doc_context):
#     with open("test_case1.json", "r", encoding="utf-8") as f:
#         gt_data = json.load(f)
#
#     # 2) map document_id -> local file
#     records = []
#     for entry in gt_data:
#         doc_id = entry["document_id"]
#         query = entry["query"]
#         gt_ans = entry["ground_truth_answer"]
#         # model_ans = rag_query_response(query, emb_text)
#         model_ans = rag_query_response(query, doc_context)
#         records.append({
#             "document_id": doc_id,
#             "query": query,
#             "ground_truth_answer": gt_ans,
#             "model_answer": model_ans
#         })
#         print(f"✅ Done {doc_id} / “{query}”")
#
#     # 3) push to DataFrame + CSV
#     df = pd.DataFrame(records)
#     out = "evaluation_results.csv"
#     df.to_csv(out, index=False, encoding="utf-8")
#     print(f"\n📝 Saved {len(df)} rows to {out}")
#     # you could log this somewhere
#
# def _run_evaluation():
#     try:
#         run_eval()
#     except Exception as e:
#         print("‼️ Evaluation script error:", e)
#
# if st.session_state.processed and not st.session_state.get("evaluation_launched", False):
#     st.session_state.evaluation_launched = True
#
#     # inform the user
#     st.sidebar.info("🔬 Starting background evaluation run…")
#
#     # capture the context
#     doc_ctx = st.session_state.document_context
#
#     # spawn the thread, passing doc_ctx in
#     threading.Thread(
#         target=lambda: run_eval(doc_ctx),
#         daemon=True
#     ).start()
#     st.sidebar.success("✅ Evaluation launched; check evaluation_results.csv when done.")
#
#     # check for file existence & show a download button
#     eval_path = os.path.abspath("evaluation_results.csv")
#     if os.path.exists(eval_path):
#         st.sidebar.success(f"✅ Results saved to:\n`{eval_path}`")
#         # load it into a small dataframe (optional)
#         df_eval = pd.read_csv(eval_path)
#         # add a download button
#         st.sidebar.download_button(
#             label="⬇️ Download evaluation_results.csv",
#             data=df_eval.to_csv(index=False).encode("utf-8"),
#             file_name="evaluation_results.csv",
#             mime="text/csv"
#         )
#     else:
#         # if you want, display the cwd so you can inspect it
#         st.sidebar.info(f"Current working dir:\n`{os.getcwd()}`")
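
# Illustrative sketch (not executed), assuming the evaluation hook above is re-enabled:
# test_case1.json is expected to be a list of records carrying the keys the hook reads,
# e.g.
#   [
#     {"document_id": "case_001",
#      "query": "What was the final judgement?",
#      "ground_truth_answer": "The appeal was dismissed."}
#   ]
# The "case_001" id, query, and answer are made-up placeholders showing the format only.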