import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
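
# Assumed dependencies (package names only; versions are not pinned
# anywhere in this file):
#   pip install streamlit torch transformers peft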

@st.cache_resource
def load_model():
    """Load the tokenizer, base model, and LoRA adapter once.

    Streamlit reruns this script on every interaction, so without caching
    the 7B model would be reloaded after every button click.
    """
    # Base model from Hugging Face.
    model_name = "deepseek-ai/deepseek-math-7b-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # float16 is poorly supported on most CPUs, so load in float32.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32
    )
    base_model = base_model.to("cpu")

    # LoRA adapter: a local directory containing adapter_config.json
    # and adapter_model.safetensors.
    adapter_dir = "trained-math-meme-model"
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model = model.to("cpu")
    model.eval()  # inference only
    return tokenizer, model


def main():
    st.title("Math Meme Repair (LoRA-Fine-Tuned)")

    st.markdown("""
    **Instructions**:
    1. Enter your incorrect math meme in the format:
       ```
       Math Meme Correction:
       Incorrect: 5-3-1 = 3?
       Correct:
       ```
    2. Click **Repair Math Meme** to generate a corrected explanation.

    **Note**: Inference runs on CPU, so a 7B model will be slow and memory-intensive.
    """)

    # 1. Load the tokenizer, base model, and LoRA adapter (cached across reruns).
    tokenizer, model = load_model()
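    # Optional tweak (an assumption, not in the original code): calling
    # model.merge_and_unload() inside load_model() folds the LoRA weights
    # into the base model, which can speed up CPU inference.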

    # 2. Configure generation. temperature and top_p only take effect
    # when sampling, so do_sample must be enabled explicitly.
    generation_config = GenerationConfig(
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        pad_token_id=tokenizer.eos_token_id
    )

    # 3. User input area
    user_input = st.text_area(
        "Enter your math meme input:",
        value="Math Meme Correction:\nIncorrect: 5-3-1 = 3?\nCorrect:"
    )

    if st.button("Repair Math Meme"):
        if user_input.strip() == "":
            st.warning("Please enter a math meme input following the required format.")
        else:
            with st.spinner("Generating on CPU; this can take a while..."):
                with torch.no_grad():
                    # Tokenize on CPU and generate with the shared config.
                    encoding = tokenizer(user_input, return_tensors="pt").to("cpu")
                    outputs = model.generate(
                        input_ids=encoding.input_ids,
                        attention_mask=encoding.attention_mask,
                        generation_config=generation_config
                    )

            # Decode and display
            result = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.subheader("Repaired Math Meme")
            st.write(result)

            st.markdown("\n**Error Rating:** 90% sass, 10% patience (on CPU)")

if __name__ == "__main__":
    main()
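
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py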