import itertools
import pickle
# Import and download necessary NLTK data for tokenization.
import nltk
from nltk.translate.bleu_score import sentence_bleu
nltk.download('punkt')
nltk.download('punkt_tab')  # newer NLTK releases look up 'punkt_tab' for word_tokenize
# Import the ROUGE metric implementation.
from rouge import Rouge
rouge = Rouge()
from datasets import load_dataset
import streamlit as st
# Stream the CC-MAIN-2024-10 dump of FineWeb; pass name="sample-10BT" instead to use the 10BT sample.
fw = load_dataset("HuggingFaceFW/fineweb", name="CC-MAIN-2024-10", split="train", streaming=True)
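# Each streamed record is a plain dict; only its "text" field is used below.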
# Define helper functions for character-level accuracy and precision.
def char_accuracy(true_output, model_output):
    # Compare matching characters in corresponding positions.
    matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
    # Account for any extra characters in either string.
    total = max(len(true_output), len(model_output))
    return matches / total if total > 0 else 1.0


def char_precision(true_output, model_output):
    # Precision is the number of matching characters divided by the length of the model's output.
    matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
    return matches / len(model_output) if len(model_output) > 0 else 0.0
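# Illustrative values (not computed by the app), using "dlroW olleH" as the reference:
#   char_accuracy("dlroW olleH", "dlroW olleX")  -> 10/11 ≈ 0.909  (one mismatched character)
#   char_precision("dlroW olleH", "dlroW")       -> 5/5 = 1.0      (every emitted character matches, but the output is truncated)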
# Initialize Streamlit app
st.title("Model Evaluation App")
st.write("This app evaluates a model's ability to reverse input text character by character.")
# Parameters
word_threshold = st.sidebar.number_input("Word Threshold", value=100, step=10)
num_samples = st.sidebar.number_input("Number of Samples", value=1, step=1)
# Get samples
samples = list(itertools.islice(fw, num_samples))
acc = []
pres = []
bleu = []
rouges = []
for i, x in enumerate(samples):
    words = x["text"].split(" ")
    # Walk over the document in consecutive chunks of `word_threshold` words.
    for n in range(len(words) // word_threshold):
        inp = words[word_threshold * n: word_threshold * (n + 1)]
        inp = " ".join(inp).replace("\n", "")

        # Display the input text.
        st.subheader("Input Text")
        st.write(inp)
        # Instruction prompt for the model under evaluation; its response is pasted into the text box below.
        prompt = (
            "You are a helpful assistant that echoes the user's input, but backwards, "
            "do not simply rearrange the words, reverse the user's input down to the character "
            "(e.g. reverse Hello World to dlroW olleH). Surround the backwards version of the "
            "user's input with <back> </back> tags. " + inp
        )
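        # For illustration: if inp were "Hello World", a correct response would be
        # "<back>dlroW olleH</back>"; only the text between the tags is scored below.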
        # Ground truth: the input reversed character by character.
        true_output = inp[::-1]
        st.subheader("True Output")
        st.write(true_output)

        # Collect the model's response from a text box (paste the output of the model being evaluated).
        # A unique key is required because this widget is created once per chunk.
        model_output_full = st.text_input("Model Output:", "", key=f"model_output_{i}_{n}")

        # Extract the text between the <back> and </back> tags; fall back to the raw text if the tags are missing.
        tag1 = model_output_full.find("<back>")
        tag2 = model_output_full.find("</back>")
        if tag1 != -1 and tag2 != -1:
            model_output = model_output_full[tag1 + len("<back>"): tag2]
        else:
            model_output = model_output_full.strip()

        st.subheader("Model Output")
        st.write(model_output)

        # Metrics are skipped until a model output has been provided (ROUGE rejects an empty hypothesis).
        if not model_output:
            st.info("Paste the model's output above to compute metrics for this chunk.")
            st.markdown("---")
            continue
        # Tokenize both outputs for BLEU calculation.
        reference_tokens = nltk.word_tokenize(true_output)
        candidate_tokens = nltk.word_tokenize(model_output)
        # Compute BLEU score (using the single reference).
        bleu_score = sentence_bleu([reference_tokens], candidate_tokens)
        st.write("**BLEU Score:**", bleu_score)
        # Compute ROUGE scores.
        rouge_scores = rouge.get_scores(model_output, true_output)
        st.write("**ROUGE Scores:**")
        st.json(rouge_scores)
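        # get_scores returns one dict per hypothesis/reference pair, e.g.
        # [{"rouge-1": {"r": ..., "p": ..., "f": ...}, "rouge-2": {...}, "rouge-l": {...}}].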
        # Compute character-level accuracy and precision.
        accuracy_metric = char_accuracy(true_output, model_output)
        precision_metric = char_precision(true_output, model_output)
        st.write("**Character Accuracy:**", accuracy_metric)
        st.write("**Character Precision:**", precision_metric)
        st.markdown("---")

        # Append the metrics for this chunk to the running lists.
        acc.append(accuracy_metric)
        pres.append(precision_metric)
        bleu.append(bleu_score)
        rouges.append(rouge_scores)
# Allow the user to download the pickled metrics.
if st.button("Download Metrics"):
    with open('accuracy.pkl', 'wb') as file:
        pickle.dump(acc, file)
    with open('precision.pkl', 'wb') as file:
        pickle.dump(pres, file)
    with open('bleu.pkl', 'wb') as file:
        pickle.dump(bleu, file)
    with open('rouge.pkl', 'wb') as file:
        pickle.dump(rouges, file)
    st.success("Metrics saved successfully!")

    # Provide download links for the saved files.
    st.download_button('Download Accuracy Metrics', data=open('accuracy.pkl', 'rb'), file_name='accuracy.pkl')
    st.download_button('Download Precision Metrics', data=open('precision.pkl', 'rb'), file_name='precision.pkl')
    st.download_button('Download BLEU Metrics', data=open('bleu.pkl', 'rb'), file_name='bleu.pkl')
    st.download_button('Download ROUGE Metrics', data=open('rouge.pkl', 'rb'), file_name='rouge.pkl')
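# The pickled lists can be inspected later outside the app, for example:
#   import pickle
#   with open('accuracy.pkl', 'rb') as f:
#       print(pickle.load(f))  # per-chunk character-accuracy values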