import asyncio
import hashlib
import os
import re
import shelve
import sys
import time
from datetime import datetime

import dateutil.parser
import docx2txt
import nltk
import numpy as np
import PyPDF2
import streamlit as st
import torch
from dotenv import load_dotenv
from nltk import sent_tokenize
from openai import OpenAI
from sentence_transformers import SentenceTransformer, util
from transformers import LEDTokenizer, LEDForConditionalGeneration, pipeline

# NLTK sentence-tokenizer data (newer NLTK releases also need "punkt_tab").
nltk.download('punkt')
nltk.download('punkt_tab')

# On Windows, fall back to the selector event loop policy; the default
# Proactor loop can break async code used by some libraries under Streamlit.
if sys.platform.startswith("win"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

st.set_page_config(page_title="Legal Document Summarizer", layout="wide")

if "processed" not in st.session_state: |
|
st.session_state.processed = False |
|
if "last_uploaded_hash" not in st.session_state: |
|
st.session_state.last_uploaded_hash = None |
|
if "chat_prompt_processed" not in st.session_state: |
|
st.session_state.chat_prompt_processed = False |
|
|
|
if "embedding_text" not in st.session_state: |
|
st.session_state.embedding_text = None |
|
|
|
if "document_context" not in st.session_state: |
|
st.session_state.document_context = None |
|
|
|
if "last_prompt_hash" not in st.session_state: |
|
st.session_state.last_prompt_hash = None |
|
|
|
|
|
st.title("📄 Legal Document Summarizer (Document Augmentation RAG)")

USER_AVATAR = "👤"
BOT_AVATAR = "🤖"

def load_chat_history():
    with shelve.open("chat_history") as db:
        return db.get("messages", [])


def save_chat_history(messages):
    with shelve.open("chat_history") as db:
        db["messages"] = messages


def limit_text(text, word_limit=500):
    words = text.split()
    return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")

def clean_text(text):
    # Collapse line breaks and repeated whitespace into single spaces.
    text = text.replace('\r\n', ' ').replace('\n', ' ')
    text = re.sub(r'\s+', ' ', text)

    # Drop "Page X of Y" footers.
    text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)

    # Drop long runs of underscores, dashes, and dots used as separators.
    text = re.sub(r'[_]{5,}', '', text)
    text = re.sub(r'[-]{5,}', '', text)
    text = re.sub(r'[.]{4,}', '', text)

    text = text.strip()

    return text
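
# Illustrative example (comment only, not executed):
#   clean_text("Held in favour.\nPage 2 of 10\n________\nCosts awarded.")
# removes the line breaks, the "Page 2 of 10" footer and the underscore run;
# note that whitespace is collapsed *before* the removals, so a couple of
# leftover spaces may remain where a footer used to be.
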

load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

# Nebius AI Studio exposes an OpenAI-compatible endpoint; the same client is
# used for embeddings and chat completions below.
client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=os.getenv("OPENAI_API_KEY")
)


@st.cache_resource
def load_local_zero_shot_classifier():
    return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")


local_classifier = load_local_zero_shot_classifier()

SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]


def classify_chunk(text):
    result = local_classifier(text, candidate_labels=SECTION_LABELS)
    return result["labels"][0]


def section_by_zero_shot(text):
    # Keys must match SECTION_LABELS so classified chunks land in the right bucket.
    sections = {"Facts": "", "Arguments": "", "Judgement": "", "Others": ""}
    sentences = sent_tokenize(text)
    chunk = ""

    # Classify the document three sentences at a time.
    for i, sent in enumerate(sentences):
        chunk += sent + " "
        if (i + 1) % 3 == 0 or i == len(sentences) - 1:
            label = classify_chunk(chunk.strip())
            print(f"Chunk: {chunk[:60]}...\nPredicted Label: {label}")

            label = label.capitalize()
            if label not in sections:
                label = "Others"
            sections[label] += chunk + "\n"
            chunk = ""

    return sections
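
# Rough usage sketch (assumes `raw_text` is the full text of a judgment):
#   sections = section_by_zero_shot(clean_text(raw_text))
#   sections["Facts"]      # sentences the zero-shot model assigned to "Facts"
#   sections["Judgement"]  # likewise for the operative part of the decision
# Each chunk of ~3 sentences is classified independently, so very short
# documents may end up almost entirely in one bucket.
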

def extract_text(file):
    if file.name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
    elif file.name.endswith(".docx"):
        full_text = docx2txt.process(file)
    elif file.name.endswith(".txt"):
        full_text = file.read().decode("utf-8")
    else:
        return "Unsupported file type."

    return full_text

@st.cache_resource
def load_legalbert():
    return SentenceTransformer("nlpaueb/legal-bert-base-uncased")


legalbert_model = load_legalbert()


@st.cache_resource
def load_led():
    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
    return tokenizer, model


tokenizer_led, model_led = load_led()


def legalbert_extractive_summary(text, top_ratio=0.2):
    # Score each sentence by cosine similarity to the mean document embedding
    # and keep the top_ratio most central sentences, in original order.
    sentences = sent_tokenize(text)
    top_k = max(3, int(len(sentences) * top_ratio))
    if len(sentences) <= top_k:
        return text
    sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
    doc_embedding = torch.mean(sentence_embeddings, dim=0)
    cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=top_k)
    selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
    return " ".join(selected_sentences)

def led_abstractive_summary(text, max_length=512, min_length=100):
    inputs = tokenizer_led(
        text, return_tensors="pt", padding="max_length",
        truncation=True, max_length=4096
    )
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    outputs = model_led.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=4,
        repetition_penalty=2.0,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=4
    )

    return tokenizer_led.decode(outputs[0], skip_special_tokens=True)

def led_abstractive_summary_chunked(text, max_tokens=3000):
    sentences = sent_tokenize(text)
    current_chunk, chunks, summaries = "", [], []
    for sent in sentences:
        if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
            chunks.append(current_chunk)
            current_chunk = sent
        else:
            current_chunk += " " + sent
    if current_chunk:
        chunks.append(current_chunk)

    for chunk in chunks:
        inputs = tokenizer_led(chunk, return_tensors="pt", padding="max_length", truncation=True, max_length=4096)
        global_attention_mask = torch.zeros_like(inputs["input_ids"])
        global_attention_mask[:, 0] = 1
        output = model_led.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            max_length=512,
            min_length=100,
            num_beams=4,
            repetition_penalty=2.0,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=4,
        )
        summaries.append(tokenizer_led.decode(output[0], skip_special_tokens=True))
    return " ".join(summaries)

def hybrid_summary_hierarchical(text, top_ratio=0.8):
    cleaned_text = clean_text(text)
    sections = section_by_zero_shot(cleaned_text)

    structured_summary = {}

    for name, content in sections.items():
        if content.strip():
            # Extractive pass (LegalBERT) followed by an abstractive pass (LED).
            extractive = legalbert_extractive_summary(content, top_ratio)
            abstractive = led_abstractive_summary_chunked(extractive)

            structured_summary[name] = {
                "extractive": extractive,
                "abstractive": abstractive
            }

    return structured_summary
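
# Rough usage sketch:
#   summary = hybrid_summary_hierarchical(raw_text)
#   summary["Facts"]["extractive"]    # top sentences picked by LegalBERT
#   summary["Facts"]["abstractive"]   # LED rewrite of those sentences
# Sections the zero-shot classifier left empty are skipped entirely.
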

def chunk_text_custom(text, n=1000, overlap=200):
    chunks = []
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])
    return chunks
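
# Illustrative: with n=1000 and overlap=200 the window advances 800 characters
# at a time, so a 2,500-character document yields chunks starting at offsets
# 0, 800, 1600 and 2400 (the last chunk is shorter than 1000 characters).
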

def get_embedding(text, model="BAAI/bge-en-icl"):
    """Create an embedding for the given text chunk using the BGE-ICL model."""
    resp = client.embeddings.create(model=model, input=text)
    return np.array(resp.data[0].embedding)

def semantic_search(query, text_chunks, chunk_embeddings, k=5):
    """
    Compute cosine similarity between the query embedding and each chunk embedding,
    then pick the top-k chunks.
    """
    q_emb = get_embedding(query)

    def cosine(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    scores = [cosine(q_emb, emb) for emb in chunk_embeddings]
    top_idxs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [text_chunks[i] for i in top_idxs]
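
# Rough usage sketch (assumes the chunks were embedded with get_embedding):
#   chunks = chunk_text_custom(raw_text)
#   embs = [get_embedding(c) for c in chunks]
#   top = semantic_search("What damages were awarded?", chunks, embs, k=3)
# Note the query is embedded with the same BGE model as the chunks; the app's
# main flow uses SimpleVectorStore.search below instead of this helper.
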

def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
    return client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}]
    ).choices[0].message.content


def generate_questions(text_chunk, num_questions=5,
                       model="meta-llama/Llama-3.2-3B-Instruct"):
    system_prompt = (
        "You are an expert at generating relevant questions from text. "
        "Create concise questions that can be answered using only the provided text."
    )
    user_prompt = f"""
    Based on the following text, generate {num_questions} different questions
    that can be answered using only this text:

    {text_chunk}

    Format your response as a numbered list of questions only.
    """
    resp = client.chat.completions.create(
        model=model,
        temperature=0.7,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )
    raw = resp.choices[0].message.content.strip()
    questions = []
    for line in raw.split("\n"):
        # Strip the "1. " numbering and keep only lines that look like questions.
        q = re.sub(r"^\d+\.\s*", "", line).strip()
        if q.endswith("?"):
            questions.append(q)
    return questions

def create_embeddings(text, model="BAAI/bge-en-icl"):
    resp = client.embeddings.create(model=model, input=text)
    return resp.data[0].embedding


def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


class SimpleVectorStore:
    def __init__(self):
        self.items = []

    def add_item(self, text, embedding, metadata):
        self.items.append(dict(text=text, embedding=embedding, metadata=metadata))

    def search(self, query, k=5):
        q_emb = create_embeddings(query)
        scores = [(i, cosine_similarity(q_emb, item["embedding"]))
                  for i, item in enumerate(self.items)]
        scores.sort(key=lambda x: x[1], reverse=True)
        return [self.items[i] for i, _ in scores[:k]]
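
# Rough usage sketch:
#   store = SimpleVectorStore()
#   store.add_item("Some chunk text", create_embeddings("Some chunk text"),
#                  {"type": "chunk", "index": 0})
#   hits = store.search("query about the chunk", k=1)   # -> list of stored items
# This is a plain in-memory list with brute-force cosine search: fine for a
# single document, not meant as a persistent or large-scale index.
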

def process_document(raw_text,
                     chunk_size=1000,
                     chunk_overlap=200,
                     questions_per_chunk=5):
    # Split the raw text into overlapping character chunks.
    chunks = []
    for i in range(0, len(raw_text), chunk_size - chunk_overlap):
        chunks.append(raw_text[i:i + chunk_size])

    store = SimpleVectorStore()
    for idx, chunk in enumerate(chunks):
        # Index the chunk itself...
        emb = create_embeddings(chunk)
        store.add_item(chunk, emb, {"type": "chunk", "index": idx})

        # ...and augment it with generated questions that point back to it.
        qs = generate_questions(chunk, num_questions=questions_per_chunk)
        for q in qs:
            q_emb = create_embeddings(q)
            store.add_item(q, q_emb, {
                "type": "question",
                "chunk_index": idx,
                "original_chunk": chunk
            })
    return chunks, store
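
# Rough usage sketch:
#   chunks, store = process_document(raw_text, questions_per_chunk=3)
# For each chunk the store ends up holding one "chunk" item plus up to
# questions_per_chunk "question" items, so a query can match either the chunk
# text itself or a question the chunk answers (document augmentation).
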

def prepare_context(results):
    seen = set()
    ctx = []

    for r in results:
        m = r["metadata"]
        if m["type"] == "chunk" and m["index"] not in seen:
            seen.add(m["index"])
            ctx.append(f"Chunk {m['index']}:\n{r['text']}")

    for r in results:
        m = r["metadata"]
        if m["type"] == "question":
            ci = m["chunk_index"]
            if ci not in seen:
                seen.add(ci)
                ctx.append(f"Chunk {ci} (via Q '{r['text']}'):\n{m['original_chunk']}")
    return "\n\n".join(ctx)

def generate_response_from_context(query, context,
                                   model="meta-llama/Llama-3.2-3B-Instruct"):
    sp = (
        "You are an AI assistant that strictly answers based on the given context. "
        "If the answer cannot be derived directly from the provided context, "
        "respond with: 'I do not have enough information to answer that.'"
    )
    up = f"""
    Context:
    {context}

    Question: {query}

    Please answer the question based only on the context above.
    """
    resp = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[{"role": "system", "content": sp},
                  {"role": "user", "content": up}]
    )
    return resp.choices[0].message.content

if "messages" not in st.session_state: |
|
st.session_state.messages = load_chat_history() |
|
|
|
|
|
if "last_uploaded" not in st.session_state: |
|
st.session_state.last_uploaded = None |
|
|
|
|
|
|
|
|
|
with st.sidebar: |
|
st.subheader("βοΈ Options") |
|
if st.button("Delete Chat History"): |
|
st.session_state.messages = [] |
|
st.session_state.last_uploaded = None |
|
st.session_state.processed = False |
|
st.session_state.chat_prompt_processed = False |
|
save_chat_history([]) |
|
|
|
|
|
|
|
def display_with_typing_effect(text, speed=0.005):
    placeholder = st.empty()
    displayed_text = ""
    for char in text:
        displayed_text += char
        placeholder.markdown(displayed_text)
        time.sleep(speed)
    return displayed_text


# Replay the stored conversation.
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])


prompt = st.chat_input("Type a message...")


with st.container():
    st.subheader("📄 Upload a Legal Document")
    uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
    reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")

def get_file_hash(file):
    file.seek(0)
    content = file.read()
    file.seek(0)
    return hashlib.md5(content).hexdigest()

def prepare_text_for_embedding(summary_dict):
    combined_chunks = []

    for section, content in summary_dict.items():
        extractive = content.get("extractive", "").strip()
        abstractive = content.get("abstractive", "").strip()
        if extractive:
            combined_chunks.append(f"{section} - Extractive Summary:\n{extractive}")
        if abstractive:
            combined_chunks.append(f"{section} - Abstractive Summary:\n{abstractive}")

    return "\n\n".join(combined_chunks)

user_role = st.sidebar.selectbox(
    "🎭 Select Your Role for Custom Summary",
    ["General", "Judge", "Lawyer", "Student"]
)


def role_based_filter(section, summary, role):
    if role == "General":
        return summary

    filtered_summary = {
        "extractive": "",
        "abstractive": ""
    }

    # Section names here must match the keys produced by section_by_zero_shot.
    if role == "Judge" and section in ["Judgement", "Facts"]:
        filtered_summary = summary
    elif role == "Lawyer" and section in ["Arguments", "Facts"]:
        filtered_summary = summary
    elif role == "Student" and section in ["Facts"]:
        filtered_summary = summary

    return filtered_summary
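
# Illustrative: role_based_filter("Arguments", summary["Arguments"], "Judge")
# returns empty summaries (a Judge only keeps "Judgement" and "Facts"), while the
# same call with role="Lawyer" passes the section through unchanged. The current
# flow builds a role-specific prompt instead of calling this filter directly.
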

if uploaded_file:
    file_hash = get_file_hash(uploaded_file)
    if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
        st.session_state.processed = False

    if not st.session_state.processed:
        start_time = time.time()

        # Summarise the document and build the question-augmented vector store.
        raw_text = extract_text(uploaded_file)
        summary_dict = hybrid_summary_hierarchical(raw_text)
        embedding_text = prepare_text_for_embedding(summary_dict)

        chunks, store = process_document(raw_text,
                                         chunk_size=1000,
                                         chunk_overlap=200,
                                         questions_per_chunk=5)
        st.session_state.vector_store = store
        st.session_state.document_context = embedding_text

        if user_role == "General":
            role_specific_prompt = (
                "Summarize the legal document focusing on the most relevant aspects "
                "such as facts, arguments, and judgments. Include key legal reasoning "
                "and a timeline of events where necessary."
            )
        else:
            role_specific_prompt = (
                f"As a {user_role}, summarize the legal document focusing on "
                "the most relevant aspects such as facts, arguments, and judgments "
                "tailored for your role. Include key legal reasoning and timeline of events."
            )

        # Retrieve the most relevant chunks for the role prompt and summarise them.
        results = store.search(role_specific_prompt, k=5)
        context = prepare_context(results)
        rag_summary = generate_response_from_context(role_specific_prompt, context)

        st.session_state.messages.append({
            "role": "user",
            "content": f"📤 Uploaded **{uploaded_file.name}**"
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rag_summary
        })
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(rag_summary)

        processing_time = round((time.time() - start_time) / 60, 2)
        st.info(f"⏱️ Response generated in **{processing_time} minutes**.")

        st.session_state.generated_summary = rag_summary
        st.session_state.last_uploaded_hash = file_hash
        st.session_state.processed = True
        st.session_state.last_prompt_hash = None
        save_chat_history(st.session_state.messages)

if prompt:
    words = prompt.split()
    word_count = len(words)
    prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()

    # Long pastes are treated as a document to ingest; short ones as questions.
    if word_count > 30 and prompt_hash != st.session_state.last_prompt_hash:
        st.session_state.last_prompt_hash = prompt_hash

        raw_text = prompt
        st.session_state.messages.append({
            "role": "user",
            "content": f"📥 **Pasted Document Text:**\n\n{limit_text(raw_text, 500)}"
        })
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(limit_text(raw_text, 500))

        start_time = time.time()

        summary_dict = hybrid_summary_hierarchical(raw_text)
        emb_text = prepare_text_for_embedding(summary_dict)
        st.session_state.document_context = emb_text
        st.session_state.processed = True

        chunks, store = process_document(raw_text)
        st.session_state.vector_store = store

        if user_role == "General":
            role_prompt = (
                "Summarize the document focusing on facts, arguments, judgments, "
                "and include a timeline of events."
            )
        else:
            role_prompt = (
                f"As a {user_role}, summarize the document focusing on facts, "
                "arguments, judgments, plus timeline of events."
            )

        results = store.search(role_prompt, k=5)
        context = prepare_context(results)
        initial_summary = generate_response_from_context(role_prompt, context)

        st.session_state.messages.append({
            "role": "assistant",
            "content": initial_summary
        })
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(initial_summary)

        st.info(f"⏱️ Summary generated in {round((time.time() - start_time) / 60, 2)} minutes")
        save_chat_history(st.session_state.messages)

    elif word_count <= 30 and st.session_state.processed:
        # Short prompt: answer a question against the already-ingested document.
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})
        store = st.session_state.vector_store

        results = store.search(prompt, k=5)
        context = prepare_context(results)
        answer = generate_response_from_context(prompt, context)

        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(answer)
        save_chat_history(st.session_state.messages)

    else:
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            st.markdown("❗ Paste at least 30 words of your document to ingest it first.")


import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.metrics import f1_score


@st.cache_resource
def load_evaluators():
    rouge = evaluate.load("rouge")
    bertscore = evaluate.load("bertscore")
    return rouge, bertscore


rouge, bertscore = load_evaluators()

def evaluate_summary(generated_summary, ground_truth_summary):
    """Evaluate ROUGE and BERTScore."""
    rouge_result = rouge.compute(predictions=[generated_summary], references=[ground_truth_summary])
    bert_result = bertscore.compute(predictions=[generated_summary], references=[ground_truth_summary], lang="en")
    return rouge_result, bert_result


def exact_match(prediction, ground_truth):
    return int(prediction.strip().lower() == ground_truth.strip().lower())


def compute_bleu(prediction, ground_truth):
    reference = [ground_truth.strip().split()]
    candidate = prediction.strip().split()
    smoothie = SmoothingFunction().method4
    return sentence_bleu(reference, candidate, smoothing_function=smoothie)

def compute_f1(prediction, ground_truth):
    """Compute F1 score based on token overlap, like in QA evaluation."""
    pred_tokens = prediction.strip().lower().split()
    gt_tokens = ground_truth.strip().lower().split()

    common_tokens = set(pred_tokens) & set(gt_tokens)
    num_common = len(common_tokens)

    if num_common == 0:
        return 0.0

    precision = num_common / len(pred_tokens)
    recall = num_common / len(gt_tokens)
    f1 = 2 * (precision * recall) / (precision + recall)
    return f1
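
# Worked example: prediction "the court dismissed the appeal" vs ground truth
# "the appeal was dismissed" share the unique tokens {"the", "appeal", "dismissed"},
# so precision = 3/5, recall = 3/4 and F1 = 2 * (0.6 * 0.75) / (0.6 + 0.75) ≈ 0.67.
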

def evaluate_additional_metrics(prediction, ground_truth):
    em = exact_match(prediction, ground_truth)
    bleu = compute_bleu(prediction, ground_truth)
    f1 = compute_f1(prediction, ground_truth)
    return {
        "Exact Match": em,
        "BLEU Score": bleu,
        "F1 Score": f1
    }


ground_truth_summary_file = st.file_uploader("📄 Upload Ground Truth Summary (.txt)", type=["txt"])

if ground_truth_summary_file:
    ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()

    if "generated_summary" in st.session_state and st.session_state.generated_summary:
        prediction = st.session_state.generated_summary

        rouge_result, bert_result = evaluate_summary(prediction, ground_truth_summary)

        st.subheader("📊 Evaluation Results")
        st.write("🔹 ROUGE Scores:")
        st.json(rouge_result)
        st.write("🔹 BERTScore:")
        st.json(bert_result)

        additional_metrics = evaluate_additional_metrics(prediction, ground_truth_summary)
        st.subheader("📊 Additional Evaluation Metrics")
        st.json(additional_metrics)

    else:
        st.warning("⚠️ Please generate a summary first by uploading a document.")