# Source appears to have been captured from a Hugging Face Space page
# (status shown at capture time: Sleeping).
import itertools
import pickle

import nltk
import streamlit as st
from datasets import load_dataset
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

# Download the tokenizer data needed by nltk.word_tokenize.
# Newer NLTK releases load the "punkt_tab" resource instead of the pickled
# "punkt" one, so both are requested to keep tokenization working across
# NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# Shared ROUGE scorer instance used by the evaluation loop.
rouge = Rouge()

# Stream the dataset so nothing has to be downloaded up front.
# Use name="sample-10BT" to use the 10BT sample.
fw = load_dataset("HuggingFaceFW/fineweb", name="CC-MAIN-2024-10", split="train", streaming=True)
# Helper functions for character-level accuracy and precision.
def char_accuracy(true_output: str, model_output: str) -> float:
    """Return the fraction of positions at which the two strings agree.

    Characters are compared position by position. The denominator is the
    length of the longer string, so missing or extra characters in either
    string count as mismatches.

    Args:
        true_output: The reference string.
        model_output: The string being evaluated.

    Returns:
        A value in [0.0, 1.0]; 1.0 when both strings are empty.
    """
    # zip() stops at the shorter string; the surplus is penalised below.
    matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
    # Account for any extra characters in either string.
    total = max(len(true_output), len(model_output))
    return matches / total if total > 0 else 1.0
def char_precision(true_output: str, model_output: str) -> float:
    """Return the share of the model's characters that match the reference.

    Characters are compared position by position and the number of matches
    is divided by the length of ``model_output``.

    Args:
        true_output: The reference string.
        model_output: The string being evaluated.

    Returns:
        A value in [0.0, 1.0]; 0.0 when ``model_output`` is empty.
    """
    # Precision is matching characters divided by the length of the model's output.
    matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
    return matches / len(model_output) if len(model_output) > 0 else 0.0
def _extract_back_tag(raw: str) -> str:
    """Return the text enclosed by the first <back>...</back> pair in *raw*.

    When the opening tag is missing the whole string is returned; when only
    the closing tag is missing everything after the opening tag is returned.
    This avoids slicing with the -1 that ``str.find`` yields for a missing tag.
    """
    open_tag, close_tag = "<back>", "</back>"
    start = raw.find(open_tag)
    if start == -1:
        return raw
    start += len(open_tag)
    end = raw.find(close_tag, start)
    return raw[start:] if end == -1 else raw[start:end]


# Initialize Streamlit app
st.title("Model Evaluation App")
st.write("This app evaluates a model's ability to reverse input text character by character.")

# Parameters. min_value=1 prevents a zero divisor in the chunking below and
# a negative count being passed to itertools.islice.
word_threshold = int(st.sidebar.number_input("Word Threshold", min_value=1, value=100, step=10))
num_samples = int(st.sidebar.number_input("Number of Samples", min_value=1, value=1, step=1))

# Get samples
samples = list(itertools.islice(fw, num_samples))

# Per-chunk metric accumulators; only chunks with a scored model output
# contribute an entry.
acc = []
pres = []
bleu = []
rouges = []

for sample_idx, x in enumerate(samples):
    nextt = x["text"].split(" ")
    # Only full chunks of `word_threshold` words are evaluated; a trailing
    # partial chunk is dropped.
    for n in range(len(nextt) // word_threshold):
        inp = nextt[word_threshold * n: word_threshold * (n + 1)]
        inp = " ".join(inp).replace("\n", "")
        if not inp.strip():
            # Nothing to reverse (chunk was only whitespace); an empty
            # reference cannot be scored.
            continue

        # Display the input text
        st.subheader("Input Text")
        st.write(inp)

        prompt = (
            "You are a helpful assistant that echoes the user's input, but backwards, "
            "do not simply rearrange the words, reverse the user's input down to the character "
            "(e.g. reverse Hello World to dlroW olleH). Surround the backwards version of the "
            "user's input with <back> </back> tags. " + inp
        )
        # Show the prompt so it can be copied into the model under test.
        with st.expander("Prompt"):
            st.code(prompt, language=None)

        # Ground truth: reverse the input (character by character)
        true_output = inp[::-1]
        st.subheader("True Output")
        st.write(true_output)

        # The model's reply is pasted in by the user. Each widget needs a
        # unique key because the same text_input is created once per chunk.
        model_output_full = st.text_input(
            "Model Ouput:", "", key=f"model_output_{sample_idx}_{n}"
        )

        # Extract the text between <back> and </back> tags
        model_output = _extract_back_tag(model_output_full)
        st.subheader("Model Output")
        st.write(model_output)

        if not model_output.strip():
            # No reply entered yet: ROUGE raises on an empty hypothesis, so
            # skip scoring for this chunk instead of crashing the app.
            st.info("Enter the model output above to compute metrics.")
            st.markdown("---")
            continue

        # Tokenize both outputs for BLEU calculation
        reference_tokens = nltk.word_tokenize(true_output)
        candidate_tokens = nltk.word_tokenize(model_output)

        # Compute BLEU score (using the single reference)
        bleu_score = sentence_bleu([reference_tokens], candidate_tokens)

        # Compute ROUGE scores
        try:
            rouge_scores = rouge.get_scores(model_output, true_output)
        except ValueError as err:
            # Raised when the hypothesis or reference has no scorable content.
            st.warning(f"ROUGE could not be computed: {err}")
            st.markdown("---")
            continue

        st.write("**BLEU Score:**", bleu_score)
        st.write("**ROUGE Scores:**")
        st.json(rouge_scores)

        # Compute character-level accuracy and precision
        accuracy_metric = char_accuracy(true_output, model_output)
        precision_metric = char_precision(true_output, model_output)
        st.write("**Character Accuracy:**", accuracy_metric)
        st.write("**Character Precision:**", precision_metric)
        st.markdown("---")

        # Append metrics to lists
        acc.append(accuracy_metric)
        pres.append(precision_metric)
        bleu.append(bleu_score)
        rouges.append(rouge_scores)

# Allow the user to download the metrics
if st.button("Download Metrics"):
    exports = [
        ("Download Accuracy Metrics", "accuracy.pkl", acc),
        ("Download Precision Metrics", "precision.pkl", pres),
        ("Download BLEU Metrics", "bleu.pkl", bleu),
        ("Download ROUGE Metrics", "rouge.pkl", rouges),
    ]
    payloads = []
    for label, filename, values in exports:
        # Serialize once; the same bytes are written to disk and offered
        # for download, so no file handle is left open.
        payload = pickle.dumps(values)
        with open(filename, "wb") as file:
            file.write(payload)
        payloads.append((label, filename, payload))
    st.success("Metrics saved successfully!")

    # Provide download links
    # NOTE(review): these buttons sit inside the st.button branch, so they
    # disappear on the rerun after one is clicked — confirm this is acceptable.
    for label, filename, payload in payloads:
        st.download_button(label, data=payload, file_name=filename)