File size: 29,945 Bytes
ce7b020 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 |
import streamlit as st
import shelve
import docx2txt
import PyPDF2
import time # Used to simulate typing effect
import nltk
import re
import os
import time # already imported in your code
from dotenv import load_dotenv
import torch
from sentence_transformers import SentenceTransformer, util
nltk.download('punkt')
import hashlib
from nltk import sent_tokenize
nltk.download('punkt_tab')
from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers import pipeline
import asyncio
import dateutil.parser
from datetime import datetime
import sys
from openai import OpenAI
import numpy as np
# Fix for "RuntimeError: no running event loop" on Windows: install the
# selector-based event loop policy before any event loop is created.
if sys.platform[:3] == "win":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
st.set_page_config(page_title="Legal Document Summarizer", layout="wide")

# Per-session defaults. A key is only written when it is missing, so values
# stored by earlier reruns of this session are kept.
_SESSION_DEFAULTS = {
    "processed": False,
    "last_uploaded_hash": None,
    "chat_prompt_processed": False,
    "embedding_text": None,
    "document_context": None,
    "last_prompt_hash": None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

st.title("π Legal Document Summarizer (Document Augmentation RAG)")
# Avatars used for user and assistant chat bubbles.
USER_AVATAR = "π€"
BOT_AVATAR = "π€"
# Load chat history
def load_chat_history():
    """Return the stored chat messages, or [] when none have been saved.

    Messages are read from the ``"messages"`` key of the ``chat_history``
    shelve database (path relative to the current working directory).
    """
    with shelve.open("chat_history") as history_db:
        stored_messages = history_db.get("messages", [])
    return stored_messages
# Save chat history
def save_chat_history(messages):
    """Store *messages* under the ``"messages"`` key of the ``chat_history`` shelf."""
    history_db = shelve.open("chat_history")
    try:
        history_db["messages"] = messages
    finally:
        history_db.close()
# Preview helper: keep at most `word_limit` words (500 by default).
def limit_text(text, word_limit=500):
    """Return the first *word_limit* whitespace-separated words of *text*.

    Words are re-joined with single spaces, and ``"..."`` is appended only
    when some words were dropped.
    """
    words = text.split()
    preview = " ".join(words[:word_limit])
    if len(words) > word_limit:
        return f"{preview}..."
    return preview
# CLEAN AND NORMALIZE TEXT
# Layout artefacts that carry no content (compiled once at import time).
_PAGE_MARKER_RE = re.compile(r'Page\s+\d+\s+of\s+\d+', flags=re.IGNORECASE)
_UNDERSCORE_RULE_RE = re.compile(r'[_]{5,}')   # underscore lines: _____
_HYPHEN_RULE_RE = re.compile(r'[-]{5,}')       # hyphen lines: -----
_DOT_LEADER_RE = re.compile(r'[.]{4,}')        # dot leaders: "......"
_WHITESPACE_RE = re.compile(r'\s+')


def clean_text(text):
    """Strip layout noise from extracted document text and normalize whitespace.

    Removes "Page N of M" markers (case-insensitive), runs of 5+ underscores,
    runs of 5+ hyphens and runs of 4+ dots. Then collapses every whitespace
    run (including newlines) into a single space and trims the ends.

    Whitespace is collapsed *after* the removals. Collapsing first would leave
    double spaces where a marker or separator used to be.
    """
    text = _PAGE_MARKER_RE.sub('', text)
    # Applied one after another (not as one alternation) so that a run exposed
    # by an earlier removal is still caught by a later pattern.
    text = _UNDERSCORE_RULE_RE.sub('', text)
    text = _HYPHEN_RULE_RE.sub('', text)
    text = _DOT_LEADER_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()
#######################################################################################################################
# LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

# OpenAI-compatible client for the Nebius AI Studio endpoint; the embedding and
# chat-completion helpers below all go through it.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url="https://api.studio.nebius.com/v1/",
)
# Load once at the top; st.cache_resource reuses the pipeline across reruns.
@st.cache_resource
def load_local_zero_shot_classifier():
    """Build the local zero-shot classification pipeline used for sectioning."""
    checkpoint = "typeform/distilbert-base-uncased-mnli"
    return pipeline("zero-shot-classification", model=checkpoint)


local_classifier = load_local_zero_shot_classifier()
# Candidate labels offered to the zero-shot classifier.
SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]


def classify_chunk(text):
    """Return the top-ranked SECTION_LABELS label for *text*.

    Takes the first entry of the pipeline's ``labels`` list. Per the
    transformers docs, that list is sorted by descending score.
    """
    ranked_labels = local_classifier(text, candidate_labels=SECTION_LABELS)["labels"]
    return ranked_labels[0]
# NLP-based sectioning using zero-shot classification
def section_by_zero_shot(text):
    """Group the sentences of *text* into sections with the zero-shot classifier.

    Sentences are classified in consecutive groups of three (the last group
    may be shorter). Each group, followed by a newline, is appended to the
    section named by its predicted label. Labels that do not name a known
    section go to "Others".

    Returns:
        dict mapping every label in SECTION_LABELS to its accumulated text
        ("" for sections that received nothing).
    """
    # Keys come from SECTION_LABELS so that every label the classifier can
    # return has a bucket. A differently spelled key ("Judgment" vs the
    # "Judgement" label) would send that whole section to "Others".
    sections = {label: "" for label in SECTION_LABELS}
    sections.setdefault("Others", "")
    sentences = sent_tokenize(text)
    group_size = 3
    for start in range(0, len(sentences), group_size):
        chunk = "".join(sent + " " for sent in sentences[start:start + group_size])
        label = classify_chunk(chunk.strip())
        print(f"π Chunk: {chunk[:60]}...\nπ Predicted Label: {label}")
        # Normalize the label casing and fall back to "Others" for unknown labels.
        label = label.capitalize()
        if label not in sections:
            label = "Others"
        sections[label] += chunk + "\n"
    return sections
#######################################################################################################################
# EXTRACTING TEXT FROM UPLOADED FILES
def extract_text(file):
    """Return the full text of an uploaded PDF, DOCX or TXT file.

    The format is chosen from the extension of ``file.name``, compared
    case-insensitively so names such as ``REPORT.PDF`` are handled too.

    Args:
        file: file-like object with a ``name`` attribute (e.g. a Streamlit
            ``UploadedFile``).

    Returns:
        The extracted text (TXT files are decoded as UTF-8), or the string
        ``"Unsupported file type."`` for any other extension.
    """
    name = file.name.lower()
    if name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        # `or ""` guards pages whose extract_text() yields None/empty.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    if name.endswith(".docx"):
        return docx2txt.process(file)
    if name.endswith(".txt"):
        return file.read().decode("utf-8")
    return "Unsupported file type."
#######################################################################################################################
# EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
@st.cache_resource
def load_legalbert():
    """Load the Legal-BERT sentence encoder used for extractive summaries (cached)."""
    checkpoint = "nlpaueb/legal-bert-base-uncased"
    return SentenceTransformer(checkpoint)


legalbert_model = load_legalbert()
@st.cache_resource
def load_led():
    """Load the LED tokenizer and model for abstractive summaries (cached).

    Returns:
        (tokenizer, model) for the ``allenai/led-base-16384`` checkpoint.
    """
    checkpoint = "allenai/led-base-16384"
    led_tokenizer = LEDTokenizer.from_pretrained(checkpoint)
    led_model = LEDForConditionalGeneration.from_pretrained(checkpoint)
    return led_tokenizer, led_model


tokenizer_led, model_led = load_led()
def legalbert_extractive_summary(text, top_ratio=0.2):
    """Keep the sentences of *text* most similar to the mean sentence embedding.

    Keeps ``max(3, int(n_sentences * top_ratio))`` sentences. If that covers
    every sentence, *text* is returned unchanged. Otherwise the chosen
    sentences are returned in their original order, joined by single spaces.
    """
    sentences = sent_tokenize(text)
    keep = max(3, int(len(sentences) * top_ratio))
    if keep >= len(sentences):
        return text
    embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
    centroid = torch.mean(embeddings, dim=0)
    similarity = util.pytorch_cos_sim(centroid, embeddings)[0]
    chosen = torch.topk(similarity, k=keep).indices.tolist()
    chosen.sort()  # restore document order
    return " ".join(sentences[idx] for idx in chosen)
# LED abstractive summarization
def led_abstractive_summary(text, max_length=512, min_length=100):
    """Generate an abstractive summary of *text* with the LED model.

    The input is padded/truncated to 4096 tokens. Global attention is set on
    the first token only. Decoding uses 4-beam search with repetition controls.

    Args:
        text: source text to summarize.
        max_length: maximum generated length, in tokens.
        min_length: minimum generated length, in tokens.

    Returns:
        The decoded summary (special tokens removed).
    """
    encoded = tokenizer_led(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=4096,
    )
    input_ids = encoded["input_ids"]
    global_mask = torch.zeros_like(input_ids)
    global_mask[:, 0] = 1  # global attention on the first token
    generated = model_led.generate(
        input_ids,
        attention_mask=encoded["attention_mask"],
        global_attention_mask=global_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=4,              # beam search
        repetition_penalty=2.0,   # penalize repeated tokens
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=4,   # never repeat a 4-gram
    )
    best_sequence = generated[0]
    return tokenizer_led.decode(best_sequence, skip_special_tokens=True)
def led_abstractive_summary_chunked(text, max_tokens=3000):
    """Summarize *text* piecewise with LED and join the partial summaries.

    Sentences are packed greedily into chunks of at most about ``max_tokens``
    LED tokens (special tokens not counted). A single sentence longer than
    that forms its own chunk. Each chunk is summarized by
    ``led_abstractive_summary`` with its default settings, which truncate the
    input at 4096 tokens.

    Returns:
        The partial summaries joined by spaces, or "" when *text* has no
        sentences.
    """
    chunks = []
    current_sentences = []
    current_tokens = 0
    for sent in sent_tokenize(text):
        # Count each sentence once instead of re-tokenizing the growing chunk
        # on every iteration (quadratic in chunk length).
        sent_tokens = len(tokenizer_led(sent, add_special_tokens=False)["input_ids"])
        # Only flush a non-empty chunk, so an over-long first sentence can
        # never produce an empty chunk to summarize.
        if current_sentences and current_tokens + sent_tokens > max_tokens:
            chunks.append(" ".join(current_sentences))
            current_sentences, current_tokens = [], 0
        current_sentences.append(sent)
        current_tokens += sent_tokens
    if current_sentences:
        chunks.append(" ".join(current_sentences))
    # Reuse the single-pass summarizer (same tokenization and generation settings).
    return " ".join(led_abstractive_summary(chunk) for chunk in chunks)
def hybrid_summary_hierarchical(text, top_ratio=0.8):
    """Build a per-section extractive + abstractive summary of *text*.

    The text is cleaned and split into sections with the zero-shot
    classifier. Each non-blank section is summarized extractively
    (Legal-BERT), and that extract is then summarized abstractively (LED).

    Returns:
        dict: section name -> {"extractive": str, "abstractive": str};
        blank sections are omitted.
    """
    sections = section_by_zero_shot(clean_text(text))
    structured_summary = {}
    for section_name, section_text in sections.items():
        if not section_text.strip():
            continue
        extractive = legalbert_extractive_summary(section_text, top_ratio)
        structured_summary[section_name] = {
            "extractive": extractive,
            "abstractive": led_abstractive_summary_chunked(extractive),
        }
    return structured_summary
def chunk_text_custom(text, n=1000, overlap=200):
    """Split *text* into character windows of length *n* that overlap by *overlap*.

    A new window starts every ``n - overlap`` characters. Generation stops as
    soon as a window reaches the end of *text*, so no trailing window is
    produced that lies entirely inside the previous one.

    Args:
        text: string to split.
        n: window length in characters; must be positive.
        overlap: characters shared by consecutive windows; must be < n.

    Returns:
        list[str]: the windows in order ([] for empty *text*).

    Raises:
        ValueError: if ``n <= 0`` or ``overlap >= n`` (the step would not advance).
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    if overlap >= n:
        raise ValueError(f"overlap must be smaller than n, got overlap={overlap}, n={n}")
    chunks = []
    for start in range(0, len(text), n - overlap):
        chunks.append(text[start:start + n])
        if start + n >= len(text):
            break
    return chunks
def get_embedding(text, model="BAAI/bge-en-icl"):
    """Embed *text* with *model* through the API client.

    Returns:
        numpy.ndarray built from the embedding of the first (only) input.
    """
    response = client.embeddings.create(model=model, input=text)
    vector = response.data[0].embedding
    return np.array(vector)
def semantic_search(query, text_chunks, chunk_embeddings, k=5):
    """Return the *k* chunks whose embeddings are most cosine-similar to *query*.

    ``chunk_embeddings[i]`` must belong to ``text_chunks[i]``. Results run
    from most to least similar.
    """
    query_vec = get_embedding(query)
    query_norm = np.linalg.norm(query_vec)
    scores = [
        float(np.dot(query_vec, emb) / (query_norm * np.linalg.norm(emb)))
        for emb in chunk_embeddings
    ]
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return [text_chunks[idx] for idx in ranked[:k]]
def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Send a system + user message pair to the chat model (temperature 0) and return the reply text."""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(model=model, temperature=0, messages=conversation)
    return completion.choices[0].message.content
# Leading list marker ("1.", "2)", "-" or "*"), optionally indented.
_QUESTION_MARKER_RE = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*")


def parse_numbered_questions(raw):
    """Return the questions found in a list-formatted LLM reply.

    From each line, one leading list marker ("1.", "2)", "-" or "*") is
    removed and the line is stripped. Only lines that then end with "?" are
    kept, in order.
    """
    questions = []
    for line in raw.splitlines():
        candidate = _QUESTION_MARKER_RE.sub("", line, count=1).strip()
        if candidate.endswith("?"):
            questions.append(candidate)
    return questions


def generate_questions(text_chunk, num_questions=5,
                       model="meta-llama/Llama-3.2-3B-Instruct"):
    """Ask the chat model for *num_questions* questions answerable from *text_chunk*.

    Sampling uses temperature 0.7.

    Returns:
        list[str]: questions parsed from the reply. The count may differ from
        *num_questions*, depending on what the model returns.
    """
    system_prompt = (
        "You are an expert at generating relevant questions from text. "
        "Create concise questions that can be answered using only the provided text."
    )
    user_prompt = f"""
Based on the following text, generate {num_questions} different questions
that can be answered using only this text:
{text_chunk}
Format your response as a numbered list of questions only.
"""
    resp = client.chat.completions.create(
        model=model,
        temperature=0.7,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    # The SDK types message.content as optional; treat None as "no questions".
    raw = (resp.choices[0].message.content or "").strip()
    return parse_numbered_questions(raw)
# 2) EMBEDDINGS
def create_embeddings(text, model="BAAI/bge-en-icl"):
    """Embed *text* and return the embedding exactly as the API client returns it (no numpy conversion)."""
    first_item = client.embeddings.create(model=model, input=text).data[0]
    return first_item.embedding
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    Returns 0.0 when either vector has zero norm. The similarity is undefined
    there, and dividing would yield NaN, which breaks the score sorting in the
    vector store.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        return 0.0
    return float(np.dot(a, b) / norm_product)
# 3) VECTOR STORE
class SimpleVectorStore:
    """In-memory list of text/embedding/metadata records searched by cosine similarity."""

    def __init__(self):
        # Each record is a dict with keys "text", "embedding" and "metadata".
        self.items = []

    def add_item(self, text, embedding, metadata):
        """Append one record to the store."""
        record = {"text": text, "embedding": embedding, "metadata": metadata}
        self.items.append(record)

    def search(self, query, k=5):
        """Embed *query* and return the *k* records most cosine-similar to it.

        Records are ordered by descending similarity; equal scores keep
        insertion order (the sort is stable).
        """
        query_embedding = create_embeddings(query)
        ranked = sorted(
            self.items,
            key=lambda record: cosine_similarity(query_embedding, record["embedding"]),
            reverse=True,
        )
        return ranked[:k]
# 4) DOCUMENT PROCESSOR
def process_document(raw_text,
                     chunk_size=1000,
                     chunk_overlap=200,
                     questions_per_chunk=5):
    """Chunk *raw_text* and index each chunk plus LLM-generated questions about it.

    For every chunk ``i`` the store receives:
      * the chunk itself, metadata ``{"type": "chunk", "index": i}``;
      * each generated question, metadata
        ``{"type": "question", "chunk_index": i, "original_chunk": chunk}``.
    Every record is stored with its embedding.

    Returns:
        (chunks, store): the chunk strings and the populated SimpleVectorStore.
    """
    # Use the shared chunker instead of re-implementing the windowing loop here.
    chunks = chunk_text_custom(raw_text, n=chunk_size, overlap=chunk_overlap)
    store = SimpleVectorStore()
    for idx, chunk in enumerate(chunks):
        store.add_item(chunk, create_embeddings(chunk), {"type": "chunk", "index": idx})
        for question in generate_questions(chunk, num_questions=questions_per_chunk):
            store.add_item(question, create_embeddings(question), {
                "type": "question",
                "chunk_index": idx,
                "original_chunk": chunk,
            })
    return chunks, store
# 5) CONTEXT BUILDER
def prepare_context(results):
    """Format vector-store search results into a context string for the answer prompt.

    Directly matched chunks come first, in result order. They are followed by
    the source chunk of each matched question. Each chunk index appears at
    most once, and entries are separated by blank lines.
    """
    seen = set()
    ctx = []
    # First: chunks that matched the query directly.
    for r in results:
        m = r["metadata"]
        if m["type"] == "chunk" and m["index"] not in seen:
            seen.add(m["index"])
            ctx.append(f"Chunk {m['index']}:\n{r['text']}")
    # Then: chunks reached through a matching generated question.
    for r in results:
        m = r["metadata"]
        if m["type"] == "question" and m["chunk_index"] not in seen:
            ci = m["chunk_index"]
            seen.add(ci)
            question = r["text"]
            # Plain ASCII quotes. The curly quotes used here had been
            # mis-encoded into stray characters that ended up in the LLM prompt.
            ctx.append(f'Chunk {ci} (via Q "{question}"):\n{m["original_chunk"]}')
    return "\n\n".join(ctx)
# 6) ANSWER GENERATOR (context-grounded; the app uses this instead of generate_response)
def generate_response_from_context(query, context,
                                   model="meta-llama/Llama-3.2-3B-Instruct"):
    """Answer *query* with the chat model using only the supplied *context*.

    The system prompt asks the model to reply with a fixed refusal sentence
    when the context does not contain the answer. Temperature is 0.

    Returns:
        The model's reply text.
    """
    system_prompt = (
        "You are an AI assistant that strictly answers based on the given context. "
        "If the answer cannot be derived directly from the provided context, "
        "respond with: 'I do not have enough information to answer that.'"
    )
    user_prompt = f"""
Context:
{context}
Question: {query}
Please answer the question based only on the context above.
"""
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(model=model, temperature=0, messages=chat)
    return completion.choices[0].message.content
#######################################################################################################################
# STREAMLIT APP INTERFACE CODE
# First run of this session: restore the persisted chat history.
if "messages" not in st.session_state:
    st.session_state["messages"] = load_chat_history()
# Make sure `last_uploaded` exists so later code can read it.
if "last_uploaded" not in st.session_state:
    st.session_state["last_uploaded"] = None
# Sidebar: options, including a button that wipes the chat history.
with st.sidebar:
    st.subheader("βοΈ Options")
    if st.button("Delete Chat History"):
        # Clear the conversation and the processing flags, then persist the empty history.
        for _flag in ("processed", "chat_prompt_processed"):
            st.session_state[_flag] = False
        st.session_state.messages = []
        st.session_state.last_uploaded = None
        save_chat_history([])
# Display chat messages with a typing effect
def display_with_typing_effect(text, speed=0.005):
    """Render *text* into a placeholder one character at a time.

    After each character the growing prefix is re-rendered as markdown, then
    the function sleeps *speed* seconds. Returns the full text.
    """
    placeholder = st.empty()
    for end in range(1, len(text) + 1):
        placeholder.markdown(text[:end])
        time.sleep(speed)
    return text
# Replay the stored conversation on every rerun.
for past in st.session_state.messages:
    role = past["role"]
    with st.chat_message(role, avatar=USER_AVATAR if role == "user" else BOT_AVATAR):
        st.markdown(past["content"])
# Standard chat input field (returns None when nothing was submitted).
prompt = st.chat_input("Type a message...")
# Upload widgets sit in their own container so they are always visible.
with st.container():
    st.subheader("π Upload a Legal Document")
    uploaded_file = st.file_uploader(
        "Upload a file (PDF, DOCX, TXT)",
        type=["pdf", "docx", "txt"],
    )
    reprocess_btn = st.button("π Reprocess Last Uploaded File")
# Hashing logic
def get_file_hash(file):
    """Return the hex MD5 digest of *file*'s full contents.

    Reads from the start and rewinds to position 0 afterwards, so the file
    can be read again by the caller.
    """
    file.seek(0)
    digest = hashlib.md5(file.read()).hexdigest()
    file.seek(0)
    return digest
# Flatten the hierarchical summary into a single labelled string.
def prepare_text_for_embedding(summary_dict):
    """Combine per-section extractive/abstractive summaries into one text block.

    For every section, a non-empty extractive summary becomes
    "<section> - Extractive Summary:\\n<text>" and a non-empty abstractive
    one "<section> - Abstractive Summary:\\n<text>". Entries are joined by
    blank lines. Missing, None or blank summaries are skipped.
    """
    combined_chunks = []
    for section, content in summary_dict.items():
        # `or ""` tolerates None values. These names also avoid shadowing the
        # builtin abs(), which the previous local name did.
        extractive = (content.get("extractive") or "").strip()
        abstractive = (content.get("abstractive") or "").strip()
        if extractive:
            combined_chunks.append(f"{section} - Extractive Summary:\n{extractive}")
        if abstractive:
            combined_chunks.append(f"{section} - Abstractive Summary:\n{abstractive}")
    return "\n\n".join(combined_chunks)
##############################################################################################################
# Sidebar selector for the reader role that tailors the summary prompt.
user_role = st.sidebar.selectbox(
    "π Select Your Role for Custom Summary",
    options=["General", "Judge", "Lawyer", "Student"],
)
# Sections relevant to each non-General role. Both spellings of the judgment
# section are accepted, because section names in this file have been spelled
# both "Judgment" and "Judgement".
_ROLE_SECTIONS = {
    "Judge": {"Judgement", "Judgment", "Facts"},
    "Lawyer": {"Arguments", "Facts"},
    "Student": {"Facts"},
}


def role_based_filter(section, summary, role):
    """Return *summary* when *section* is relevant to *role*, else an empty summary.

    "General" keeps every section. For other roles, an irrelevant section, or
    an unknown role, yields {"extractive": "", "abstractive": ""}. A kept
    summary is the same object that was passed in (not a copy).
    """
    if role == "General" or section in _ROLE_SECTIONS.get(role, ()):
        return summary
    return {"extractive": "", "abstractive": ""}
#########################################################################################################################
# Uploaded-file flow: (re)process the document when its contents changed, when
# "Reprocess" was pressed, or when nothing has been processed in this session.
if uploaded_file:
    file_hash = get_file_hash(uploaded_file)
    # A file with different contents (MD5) or an explicit reprocess request
    # invalidates the previous result.
    if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
        st.session_state.processed = False
    if not st.session_state.processed:
        start_time = time.time()
        # 1) Extract the text and build the hierarchical (per-section) summary.
        # NOTE(review): extract_text returns the string "Unsupported file type."
        # for unknown extensions, and that string is then processed as if it
        # were document text.
        raw_text = extract_text(uploaded_file)
        summary_dict = hybrid_summary_hierarchical(raw_text)
        embedding_text = prepare_text_for_embedding(summary_dict)
        # --- Document-augmentation ingestion: index chunks plus generated questions ---
        # NOTE(review): `chunks` is not used below.
        chunks, store = process_document(raw_text,
                                         chunk_size=1000,
                                         chunk_overlap=200,
                                         questions_per_chunk=5)
        st.session_state.vector_store = store
        # 2) Keep the flattened summary as document context and pick the
        # summary prompt for the selected role.
        st.session_state.document_context = embedding_text
        if user_role == "General":
            role_specific_prompt = (
                "Summarize the legal document focusing on the most relevant aspects "
                "such as facts, arguments, and judgments. Include key legal reasoning "
                "and a timeline of events where necessary."
            )
        else:
            role_specific_prompt = (
                f"As a {user_role}, summarize the legal document focusing on "
                "the most relevant aspects such as facts, arguments, and judgments "
                "tailored for your role. Include key legal reasoning and timeline of events."
            )
        # --- Answer the role prompt from the vector store (document-augmentation RAG) ---
        results = store.search(role_specific_prompt, k=5)
        context = prepare_context(results)
        rag_summary = generate_response_from_context(role_specific_prompt, context)
        # Record the upload and the generated summary in the chat history.
        st.session_state.messages.append({
            "role": "user",
            "content": f"π€ Uploaded **{uploaded_file.name}**"
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rag_summary
        })
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(rag_summary)
        # Elapsed wall-clock time in minutes, rounded to 2 decimals.
        processing_time = round((time.time() - start_time) / 60, 2)
        st.info(f"β±οΈ Response generated in **{processing_time} minutes**.")
        # Read by the ground-truth evaluation section further down.
        st.session_state.generated_summary = rag_summary
        st.session_state.last_uploaded_hash = file_hash
        st.session_state.processed = True
        # Cleared so that the next long chat prompt is ingested even if it
        # repeats an earlier paste.
        st.session_state.last_prompt_hash = None
        save_chat_history(st.session_state.messages)
# Chat prompt flow.
#   * Long prompts (> 30 words) are treated as pasted document text and ingested.
#   * Short prompts are answered with RAG against the last ingested document.
if prompt:
    words = prompt.split()
    word_count = len(words)
    prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()
    is_long_prompt = word_count > 30

    # 1) New long text: echo it, then ingest it like an uploaded document.
    if is_long_prompt and prompt_hash != st.session_state.last_prompt_hash:
        raw_text = prompt
        st.session_state.messages.append({
            "role": "user",
            "content": f"π₯ **Pasted Document Text:**\n\n{limit_text(raw_text,500)}"
        })
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(limit_text(raw_text, 500))
        start_time = time.time()
        # Hierarchical summary, kept as document context.
        summary_dict = hybrid_summary_hierarchical(raw_text)
        emb_text = prepare_text_for_embedding(summary_dict)
        st.session_state.document_context = emb_text
        # Document-augmentation ingestion. The store is saved before `processed`
        # is set, so a failure here never leaves processed=True without a store.
        _, store = process_document(raw_text)
        st.session_state.vector_store = store
        st.session_state.processed = True
        if user_role == "General":
            role_prompt = (
                "Summarize the document focusing on facts, arguments, judgments, "
                "and include a timeline of events."
            )
        else:
            role_prompt = (
                f"As a {user_role}, summarize the document focusing on facts, "
                "arguments, judgments, plus timeline of events."
            )
        results = store.search(role_prompt, k=5)
        context = prepare_context(results)
        initial_summary = generate_response_from_context(role_prompt, context)
        st.session_state.messages.append({
            "role": "assistant",
            "content": initial_summary
        })
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(initial_summary)
        st.info(f"β±οΈ Summary generated in {round((time.time()-start_time)/60,2)} minutes")
        # Expose the summary to the evaluation section, as the upload flow does.
        st.session_state.generated_summary = initial_summary
        # Mark the text as ingested only after the pipeline succeeded, so a
        # failed run can be retried by submitting the same text again.
        st.session_state.last_prompt_hash = prompt_hash
        save_chat_history(st.session_state.messages)

    # 2) The same long text was submitted again: it is already ingested.
    elif is_long_prompt:
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            st.markdown(
                "This text has already been ingested. "
                "Ask a question about it, or paste a different document."
            )

    # 3) Short prompt with an ingested document: RAG against its vector store.
    elif st.session_state.processed and st.session_state.get("vector_store") is not None:
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        store = st.session_state.vector_store
        results = store.search(prompt, k=5)
        context = prepare_context(results)
        answer = generate_response_from_context(prompt, context)
        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(answer)
        save_chat_history(st.session_state.messages)

    # 4) Nothing ingested yet.
    else:
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            st.markdown("β Paste at least 30 words of your document to ingest it first.")
################################Evaluation###########################
######################################################################################################################
# Imports
import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.metrics import f1_score
# Load evaluators once
@st.cache_resource
def load_evaluators():
    """Load the ROUGE and BERTScore metrics via `evaluate` (cached across reruns).

    Returns:
        (rouge, bertscore) metric objects, loaded in that order.
    """
    return evaluate.load("rouge"), evaluate.load("bertscore")


rouge, bertscore = load_evaluators()
# Define evaluation functions
def evaluate_summary(generated_summary, ground_truth_summary):
    """Score *generated_summary* against *ground_truth_summary* with ROUGE and BERTScore.

    Returns:
        (rouge_result, bert_result) as returned by each metric's ``compute``
        (BERTScore with lang="en").
    """
    predictions = [generated_summary]
    references = [ground_truth_summary]
    rouge_result = rouge.compute(predictions=predictions, references=references)
    bert_result = bertscore.compute(predictions=predictions, references=references, lang="en")
    return rouge_result, bert_result
def exact_match(prediction, ground_truth):
    """Return 1 if the two strings match after trimming and lowercasing, else 0."""
    def normalize(value):
        return value.strip().lower()

    return 1 if normalize(prediction) == normalize(ground_truth) else 0
def compute_bleu(prediction, ground_truth):
    """Sentence-level BLEU of *prediction* against a single reference.

    Both texts are tokenized on whitespace. Scoring uses NLTK's
    ``SmoothingFunction().method4``.
    """
    candidate_tokens = prediction.strip().split()
    reference_tokens = ground_truth.strip().split()
    return sentence_bleu(
        [reference_tokens],
        candidate_tokens,
        smoothing_function=SmoothingFunction().method4,
    )
def compute_f1(prediction, ground_truth):
    """Token-overlap F1 between *prediction* and *ground_truth*, as in SQuAD QA evaluation.

    Tokens are whitespace-separated words after trimming and lowercasing.
    Overlap is a multiset intersection: a shared token counts as many times
    as it appears in both texts. (A set intersection undercounts repeated
    words; e.g. "a a" vs "a a" would score 0.5.) If either text has no
    tokens, the score is 1.0 when both are empty and 0.0 otherwise.
    """
    from collections import Counter

    pred_tokens = prediction.strip().lower().split()
    gt_tokens = ground_truth.strip().lower().split()
    if not pred_tokens or not gt_tokens:
        return float(pred_tokens == gt_tokens)
    num_common = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())
    if num_common == 0:
        return 0.0
    precision = num_common / len(pred_tokens)
    recall = num_common / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)
def evaluate_additional_metrics(prediction, ground_truth):
    """Return exact-match, BLEU and token-F1 scores keyed by display name."""
    metric_fns = (
        ("Exact Match", exact_match),
        ("BLEU Score", compute_bleu),
        ("F1 Score", compute_f1),
    )
    return {label: fn(prediction, ground_truth) for label, fn in metric_fns}
# Upload a ground-truth summary and evaluate the last generated summary against it
ground_truth_summary_file = st.file_uploader("π Upload Ground Truth Summary (.txt)", type=["txt"])
if ground_truth_summary_file:
    ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()
    prediction = st.session_state.get("generated_summary")
    if not prediction:
        st.warning("β οΈ Please generate a summary first by uploading a document.")
    else:
        # ROUGE and BERTScore.
        rouge_result, bert_result = evaluate_summary(prediction, ground_truth_summary)
        st.subheader("π Evaluation Results")
        st.write("πΉ ROUGE Scores:")
        st.json(rouge_result)
        st.write("πΉ BERTScore:")
        st.json(bert_result)
        # Exact match, BLEU and token F1.
        additional_metrics = evaluate_additional_metrics(prediction, ground_truth_summary)
        st.subheader("π Additional Evaluation Metrics")
        st.json(additional_metrics)
######################################################################################################################
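# To evaluate the model on a test set, uncomment the code below and run it alongside `streamlit run app.py`;
# otherwise leave it commented out.
# NOTE: it still calls `rag_query_response`, which no longer exists in this file, and `_run_evaluation`
# calls run_eval() without its required argument.
# EVALUATION HOOK: after the first summary is produced, launch run_eval() once in a background thread.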
# import json
# import pandas as pd
# import threading
# def run_eval(doc_context):
# with open("test_case1.json", "r", encoding="utf-8") as f:
# gt_data = json.load(f)
# # 2) map document_id → local file
# records = []
# for entry in gt_data:
# doc_id = entry["document_id"]
# query = entry["query"]
# gt_ans = entry["ground_truth_answer"]
# # model_ans = rag_query_response(query, emb_text)
# model_ans = rag_query_response(query, doc_context)
# records.append({
# "document_id": doc_id,
# "query": query,
# "ground_truth_answer": gt_ans,
# "model_answer": model_ans
# })
# print(f"✅ Done {doc_id} / “{query}”")
# # 3) push to DataFrame + CSV
# df = pd.DataFrame(records)
# out = "evaluation_results.csv"
# df.to_csv(out, index=False, encoding="utf-8")
# print(f"\nπ Saved {len(df)} rows to {out}")
# # you could log this somewhere
# def _run_evaluation():
# try:
# run_eval()
# except Exception as e:
# print("βΌοΈ Evaluation script error:", e)
# if st.session_state.processed and not st.session_state.get("evaluation_launched", False):
# st.session_state.evaluation_launched = True
# # inform user
# st.sidebar.info("π¬ Starting background evaluation runβ¦")
# # *capture* the context
# doc_ctx = st.session_state.document_context
# # spawn the thread, passing doc_ctx in
# threading.Thread(
# target=lambda: run_eval(doc_ctx),
# daemon=True
# ).start()
# st.sidebar.success("✅ Evaluation launched — check evaluation_results.csv when done.")
# # check for file existence & show download button
# eval_path = os.path.abspath("evaluation_results.csv")
# if os.path.exists(eval_path):
# st.sidebar.success(f"✅ Results saved to:\n`{eval_path}`")
# # load it into a small dataframe (optional)
# df_eval = pd.read_csv(eval_path)
# # add a download button
# st.sidebar.download_button(
# label="β¬οΈ Download evaluation_results.csv",
# data=df_eval.to_csv(index=False).encode("utf-8"),
# file_name="evaluation_results.csv",
# mime="text/csv"
# )
# else:
# # if you want, display the cwd so you can inspect it
# st.sidebar.info(f"Current working dir:\n`{os.getcwd()}`")
|