import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel

# Hugging Face id of the base model the LoRA adapter is applied to.
BASE_MODEL_NAME = "deepseek-ai/deepseek-math-7b-base"
# Local directory holding adapter_config.json and adapter_model.safetensors.
ADAPTER_DIR = "trained-math-meme-model"
# Everything runs on CPU in this app.
DEVICE = "cpu"

DEFAULT_PROMPT = "Math Meme Correction:\nIncorrect: 5-3-1 = 3?\nCorrect:"

# Kept at column 0 so Markdown does not interpret leading indentation.
INSTRUCTIONS_MD = """
**Instructions**:
1. Enter your incorrect math meme in the format:
   ```
   Math Meme Correction:
   Incorrect: 5-3-1 = 3?
   Correct:
   ```
2. Click **Repair Math Meme** to generate a corrected explanation.

**Note**: This is running on CPU, so it may be slow and memory-intensive for a 7B model.
"""


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and the LoRA-adapted model once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the 7B model would be reloaded on each button click.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is placed on ``DEVICE``.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    # float32 is used because float16 might not work well on certain CPUs.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=torch.float32,
    )
    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
    return tokenizer, model.to(DEVICE)


def _build_generation_config(tokenizer):
    """Return the sampling configuration used for every generation call."""
    return GenerationConfig(
        max_new_tokens=100,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )


def _repair_meme(prompt, tokenizer, model, generation_config):
    """Generate a completion for ``prompt`` and return the decoded text.

    The decoded text is the full output sequence (prompt included) with
    special tokens stripped.
    """
    encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def main():
    """Render the Streamlit page and run generation when the button is clicked."""
    st.title("Math Meme Repair (LoRA-Fine-Tuned)")
    st.markdown(INSTRUCTIONS_MD)

    # Cached: the expensive load happens only on the first script run.
    tokenizer, model = _load_model_and_tokenizer()
    generation_config = _build_generation_config(tokenizer)

    user_input = st.text_area(
        "Enter your math meme input:",
        value=DEFAULT_PROMPT,
    )

    if not st.button("Repair Math Meme"):
        return

    if not user_input.strip():
        st.warning("Please enter a math meme input following the required format.")
        return

    result = _repair_meme(user_input, tokenizer, model, generation_config)

    st.subheader("Repaired Math Meme")
    st.write(result)
    st.markdown("\n**Error Rating:** 90% sass, 10% patience (on CPU)")


if __name__ == "__main__":
    main()