MMC / app.py
AshenBorn's picture
Update app.py
9a1e940 verified
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
def _load_model_and_tokenizer():
    """Load the tokenizer and the LoRA-adapted causal LM on CPU.

    Loads the ``deepseek-ai/deepseek-math-7b-base`` checkpoint in float32 and
    wraps it with the LoRA adapter stored in the local
    ``trained-math-meme-model`` directory.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # 1. Load the base model from Hugging Face.
    model_name = "deepseek-ai/deepseek-math-7b-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # float32 is the CPU-friendly dtype (float16 might not work well on
    # certain CPUs).
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
    )
    base_model = base_model.to("cpu")  # We run on CPU.

    # 2. Load the LoRA adapter (local directory with adapter_config.json and
    #    adapter_model.safetensors).
    adapter_dir = "trained-math-meme-model"
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model = model.to("cpu")
    return tokenizer, model


def main():
    """Render the Math Meme Repair page and generate a correction on demand.

    Shows the instructions and a text area pre-filled with an example prompt.
    When the "Repair Math Meme" button is pressed with non-blank input, the
    prompt is tokenized, passed to the LoRA-adapted model, and the decoded
    output (prompt included) is displayed. Blank input produces a warning
    instead of a generation.
    """
    st.title("Math Meme Repair (LoRA-Fine-Tuned)")
    st.markdown(
        "\n"
        "**Instructions**:\n"
        "1. Enter your incorrect math meme in the format:\n"
        "```\n"
        "Math Meme Correction:\n"
        "Incorrect: 5-3-1 = 3?\n"
        "Correct:\n"
        "```\n"
        "2. Click **Repair Math Meme** to generate a corrected explanation.\n"
        "**Note**: This is running on CPU, so it may be slow and "
        "memory-intensive for a 7B model.\n"
    )

    # Streamlit re-executes the script on every widget interaction. Without
    # caching, the 7B model would be reloaded on each rerun (including every
    # button click), so the loader is wrapped in st.cache_resource and the
    # loaded objects are reused across reruns.
    cached_loader = st.cache_resource(_load_model_and_tokenizer)
    tokenizer, model = cached_loader()

    # 3. Configure generation. do_sample=True is required for temperature and
    #    top_p to take effect; without it decoding is greedy and both
    #    settings are ignored.
    generation_config = GenerationConfig(
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

    # 4. User input area.
    user_input = st.text_area(
        "Enter your math meme input:",
        value="Math Meme Correction:\nIncorrect: 5-3-1 = 3?\nCorrect:",
    )

    if not st.button("Repair Math Meme"):
        return

    if not user_input.strip():
        st.warning("Please enter a math meme input following the required format.")
        return

    # Tokenize on CPU.
    encoding = tokenizer(user_input, return_tensors="pt").to("cpu")

    # Generation on CPU is slow, so show progress; no_grad avoids building an
    # autograd graph during inference.
    with st.spinner("Repairing math meme..."), torch.no_grad():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )

    # Decode and display.
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.subheader("Repaired Math Meme")
    st.write(result)
    st.markdown("\n**Error Rating:** 90% sass, 10% patience (on CPU)")
# Script entry point: build the page only when this file is executed as the
# main script, not when it is imported as a module.
if __name__ == "__main__":
    main()