|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig |
|
from peft import PeftModel |
|
|
|
@st.cache_resource
def _load_model_and_tokenizer(
    model_name="deepseek-ai/deepseek-math-7b-base",
    adapter_dir="trained-math-meme-model",
):
    """Load the tokenizer and the LoRA-adapted causal LM once per process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the 7B base model (loaded in float32) would be reloaded
    from disk on each rerun. ``st.cache_resource`` keeps one shared instance
    per distinct set of arguments.

    Args:
        model_name: Hugging Face hub id (or local path) of the base model.
        adapter_dir: Directory containing the trained PEFT/LoRA adapter.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is placed on CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    ).to("cpu")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to("cpu")
    # Explicitly inference-only: make sure dropout layers are disabled.
    model.eval()
    return tokenizer, model


def main():
    """Render the Math Meme Repair app and run generation on demand.

    Shows usage instructions, a text area pre-filled with an example prompt,
    and a button. When the button is pressed with non-blank input, the prompt
    is tokenized, the LoRA-adapted model samples up to 100 new tokens, and the
    decoded text (prompt plus continuation) is displayed.
    """
    st.title("Math Meme Repair (LoRA-Fine-Tuned)")

    st.markdown("""
    **Instructions**:
    1. Enter your incorrect math meme in the format:
    ```
    Math Meme Correction:
    Incorrect: 5-3-1 = 3?
    Correct:
    ```
    2. Click **Repair Math Meme** to generate a corrected explanation.

    **Note**: This is running on CPU, so it may be slow and memory-intensive for a 7B model.
    """)

    # Cached across Streamlit reruns; only the first run pays the load cost.
    tokenizer, model = _load_model_and_tokenizer()

    # do_sample=True is required for temperature/top_p to have any effect;
    # under the default greedy decoding transformers ignores them (and warns).
    generation_config = GenerationConfig(
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

    user_input = st.text_area(
        "Enter your math meme input:",
        value="Math Meme Correction:\nIncorrect: 5-3-1 = 3?\nCorrect:",
    )

    if not st.button("Repair Math Meme"):
        return

    if not user_input.strip():
        st.warning("Please enter a math meme input following the required format.")
        return

    with torch.no_grad():
        encoding = tokenizer(user_input, return_tensors="pt").to("cpu")
        # Sampling settings are passed as keyword overrides (not as
        # generation_config=...) so they are applied on top of the model's own
        # default generation config (e.g. its EOS token) instead of replacing it.
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            max_new_tokens=generation_config.max_new_tokens,
            do_sample=generation_config.do_sample,
            temperature=generation_config.temperature,
            top_p=generation_config.top_p,
            pad_token_id=generation_config.pad_token_id,
        )

    # outputs[0] contains the prompt tokens followed by the generated ones, so
    # the displayed text is the full "... Correct: <model answer>" sequence.
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.subheader("Repaired Math Meme")
    st.write(result)
    st.markdown("\n**Error Rating:** 90% sass, 10% patience (on CPU)")
|
|
|
# Script entry point: build the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()