import gradio as gr
import torch
import time
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
from db import init_db, save_test_result, get_test_history, get_test_details

# --- Initialize Database ---
db_initialized = init_db()
if not db_initialized:
    print("WARNING: Database initialization failed. Test history will not be saved.")

# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-Math-1.5B"  # Swap in a different model ID here if needed

# --- Load Model and Tokenizer ---
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto"
)
print("Model loaded successfully.")

# --- Generation Function (Returns response and token count) ---
def generate_response(messages, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response and return it along with the number of generated tokens."""
    num_generated_tokens = 0
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        device = model.device
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
        input_ids_len = model_inputs.input_ids.shape[-1]

        generation_kwargs = {
            "max_new_tokens": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        print("Generating response...")
        with torch.no_grad():
            # Unpack the tokenizer output so both input_ids and attention_mask reach generate()
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

        output_ids = generated_ids[0, input_ids_len:]
        num_generated_tokens = len(output_ids)
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Generation complete.")
        return response.strip(), num_generated_tokens

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred: {str(e)}", num_generated_tokens
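
# Illustrative usage (sketch, not executed here): generate_response expects chat-style
# messages like the list built in process_input below. Hypothetical example:
#
#   example_messages = [
#       {"role": "system", "content": DEFAULT_SYSTEM_PROMPT_FREQ},
#       {"role": "user", "content": "Rock: 40%\nPaper: 30%\nScissors: 30%"},
#   ]
#   reply, n_tokens = generate_response(example_messages, max_length=128)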

@spaces.GPU # Keep ZeroGPU decorator
def process_input(
    analysis_mode, # Mode selector
    player_stats,
    player_behavior_input,
    system_prompt, # Single system prompt from UI
    max_length,
    temperature,
    top_p,
    save_to_db=True  # Whether to store this run in the SQLite test-history database
):
    """Process inputs based on selected analysis mode using the provided system prompt."""
    print(f"GPU requested via decorator, starting processing in mode: {analysis_mode}")

    # Create the messages list using the system_prompt from the UI directly
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    
    # Build the user message for the selected analysis mode (Markov mode sends no user message)
    if analysis_mode == "Frequency Only":
        user_content = f"Player Move Frequency Stats (Long-Term):\n{player_stats}"
        messages.append({"role": "user", "content": user_content})
    elif analysis_mode == "Behavior Analysis":
        user_content = player_behavior_input
        messages.append({"role": "user", "content": user_content})
    else:  # Markov Prediction Only mode
        # No user message is appended; the system prompt alone carries the full task.
        user_content = ""

    # --- Time Measurement Start ---
    start_time = time.time()

    # Generate response from the model
    response, generated_tokens = generate_response(
        messages,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p
    )

    # --- Time Measurement End ---
    end_time = time.time()
    duration = round(end_time - start_time, 2)

    # For display purposes - show what was actually sent to the model
    if user_content:
        display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}\n\n------\n\nUser Content:\n{user_content}"
    else:
        display_prompt = f"Selected Mode: {analysis_mode}\nSystem Prompt:\n{system_prompt}"

    print(f"Processing finished in {duration} seconds.")
    
    # Save to database if requested and if database is available
    if save_to_db and db_initialized:
        test_id = save_test_result(
            analysis_mode=analysis_mode,
            system_prompt=system_prompt,
            input_content=user_content if user_content else "",
            model_response=response,
            generation_time=duration,
            tokens_generated=generated_tokens,
            temperature=temperature,
            top_p=top_p,
            max_length=max_length
        )
        if test_id:
            print(f"Test saved to database with ID: {test_id}")
        else:
            print("Failed to save test to database")
    
    # Return all results including time and tokens
    return display_prompt, response, f"{duration} seconds", generated_tokens

# --- System Prompts (Defaults only, UI will hold the editable version) ---
DEFAULT_SYSTEM_PROMPT_FREQ = """You are an assistant that analyzes Rock-Paper-Scissors (RPS) player statistics. Your ONLY goal is to find the best single AI move to counter the player's MOST frequent move based on the provided frequency stats.

Follow these steps EXACTLY. Do NOT deviate.

Step 1: Identify Player's Most Frequent Move.
   - Look ONLY at the 'Player Move Frequency Stats'.
   - List the percentages: Rock (%), Paper (%), Scissors (%).
   - State which move name has the highest percentage number.

Step 2: Determine the Counter Move using RPS Rules.
   - REMEMBER THE RULES: Paper beats Rock. Rock beats Scissors. Scissors beats Paper.
   - Based *only* on the move identified in Step 1, state the single move name that beats it according to the rules. State the rule you used (e.g., "Paper beats Rock").

Step 3: Explain the Counter Choice.
   - Briefly state: "Playing [Counter Move from Step 2] is recommended because it directly beats the player's most frequent move, [Most Frequent Move from Step 1]."

Step 4: State Final Recommendation.
   - State *only* the recommended AI move name from Step 2. Example: "Recommendation: Paper"

Base your analysis strictly on the provided frequencies and the stated RPS rules.
"""

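# --- Reference sketch: frequency counter (illustrative only, not wired into the app) ---
# The frequency prompt above asks the model to counter the player's most frequent move.
# This minimal helper mirrors that rule so model output can be sanity-checked by hand.
# The `stats` dict format (move name -> percentage) is an assumption for this sketch.
COUNTER_MOVE = {"Rock": "Paper", "Paper": "Scissors", "Scissors": "Rock"}

def frequency_counter_move(stats):
    """Return the move that beats the player's most frequent move.

    Example: frequency_counter_move({"Rock": 40, "Paper": 30, "Scissors": 30}) -> "Paper"
    """
    most_frequent = max(stats, key=stats.get)  # move with the highest percentage
    return COUNTER_MOVE[most_frequent]
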
DEFAULT_SYSTEM_PROMPT_MARKOV = """You are analyzing a Rock-Paper-Scissors (RPS) game using a Markov transition matrix.

### TRANSITION MATRIX:
[
  [0.20, 0.60, 0.20],  # Row 0 (After Rock)
  [0.30, 0.10, 0.60],  # Row 1 (After Paper)
  [0.50, 0.30, 0.20]   # Row 2 (After Scissors)
]

### EXPLANATION:
- This matrix shows P(Next Move | Previous Move)
- Each row represents the previous move (0=Rock, 1=Paper, 2=Scissors)
- Each column represents the next move (0=Rock, 1=Paper, 2=Scissors)
- For example, entry [0,1]=0.60 means: after playing Rock, 60% chance of playing Paper next

### PLAYER INFORMATION:
- The player's last move was: Paper
- Our goal is to predict their most likely next move and determine our choice that counters the predicted move

### YOUR TASK:
1. Find the row in the matrix corresponding to the player's last move
2. From that row, identify which move has the highest probability value
3. That highest probability move is the player's predicted next move
4. Determine the optimal counter move using RPS rules:
   * Rock beats Scissors
   * Scissors beats Paper
   * Paper beats Rock

### SHOW YOUR MATHEMATICAL WORK:
- Identify the correct row number for the player's last move
- Extract all probability values from that row
- Compare the numerical values to find the maximum
- Apply game rules to determine the counter move

### OUTPUT FORMAT:
Player's Last Move: [Move]
Probabilities: [List the probabilities]
Predicted Next Move: [Move with highest probability]
Optimal Counter: [Move that beats the predicted move]
"""

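# --- Reference sketch: Markov prediction (illustrative only, not wired into the app) ---
# The Markov prompt above asks the model to read the row for the player's last move,
# take the highest-probability column, and counter it. The helper below performs the
# same lookup so model answers can be checked; it reuses COUNTER_MOVE defined above.
MOVES = ["Rock", "Paper", "Scissors"]

def markov_counter_move(transition_matrix, last_move):
    """Return (predicted next move, counter move) from a 3x3 row-stochastic matrix."""
    row = transition_matrix[MOVES.index(last_move)]       # row for the previous move
    predicted_next = MOVES[row.index(max(row))]           # argmax of that row
    return predicted_next, COUNTER_MOVE[predicted_next]

# With the matrix from the prompt and last move "Paper", the row is [0.30, 0.10, 0.60],
# so the predicted next move is Scissors and the recommended counter is Rock.
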
DEFAULT_SYSTEM_PROMPT_BEHAVIOR = """You are an RPS assistant analyzing player behavior after wins, losses, and ties. Predict the player's next move and recommend a counter strategy based on the behavioral probabilities below.

**Behavioral Probabilities P(Change/not change | Win/Loss/Tie):**
* P(not change | Win) = 0.70
* P(Change | Win) = 0.30
* P(not change | Loss) = 0.25
* P(Change | Loss) = 0.75
* P(not change | Tie) = 0.50
* P(Change | Tie) = 0.50

**Input Provided by User:**
* Player's Last Outcome: [Win/Loss/Tie]
* Player's Last Move: [Rock/Paper/Scissors]

**Your Task:**
1. Based on the Player's Last Outcome, determine the **Predicted Behavior** by comparing P(not change | Win/Loss/Tie) and P(Change | Win/Loss/Tie).
2. Determine the **Player's Predicted Next Move**:
   * If Predicted Behavior is "not change", predict the same move as Player's Last Move.
   * If Predicted Behavior is "Change", predict a move different from Player's Last Move (randomly select between the two remaining options with equal probability).
3. Recommend the **AI Counter Move** that beats the predicted player move:
   * Paper beats Rock
   * Rock beats Scissors
   * Scissors beats Paper

**Output Format:**
Predicted Behavior: [not change/Change] (Based on P(not change|Outcome)=[Prob], P(Change|Outcome)=[Prob])
Prediction Logic: [Brief explanation of your reasoning]
Predicted Player Move: [Rock/Paper/Scissors]
Recommended AI Counter: [Rock/Paper/Scissors]
"""

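# --- Reference sketch: behavior analysis (illustrative only, not wired into the app) ---
# The behavior prompt above compares P(change | outcome) with P(not change | outcome),
# predicts the next move, and counters it. The helper below mirrors that logic; the
# probability table is copied from the prompt, and the 50/50 tie case is resolved here
# by assuming the player repeats their move (the prompt leaves it ambiguous).
import random

P_CHANGE = {"Win": 0.30, "Loss": 0.75, "Tie": 0.50}

def behavior_counter_move(last_outcome, last_move):
    """Return (predicted next move, counter move) given the last outcome and move."""
    if P_CHANGE[last_outcome] > 0.5:
        # "Change": pick one of the two remaining moves with equal probability.
        predicted_next = random.choice([m for m in MOVES if m != last_move])
    else:
        # "Not change" (or the 50/50 tie): assume the player repeats their last move.
        predicted_next = last_move
    return predicted_next, COUNTER_MOVE[predicted_next]
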
# --- Default Input Values ---
DEFAULT_PLAYER_STATS = "Rock: 40%\nPaper: 30%\nScissors: 30%"
DEFAULT_PLAYER_BEHAVIOR = "Player's Last Outcome: Win\nPlayer's Last Move: Rock"

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Tab("Model Testing"):
        gr.Markdown(f"# {MODEL_ID} - RPS Strategy Tester with Test History")
        gr.Markdown("Test model advice using Frequency Stats, Markov Predictions, or Win/Loss/Tie Behavior Analysis.")

        # Mode Selector - now with three options
        analysis_mode_selector = gr.Radio(
            label="Select Analysis Mode",
            choices=["Frequency Only", "Markov Prediction Only", "Behavior Analysis"],
            value="Frequency Only" # Default mode
        )

        # --- Visible System Prompt Textbox ---
        system_prompt_input = gr.Textbox(
            label="System Prompt (Edit based on selected mode)",
            value=DEFAULT_SYSTEM_PROMPT_FREQ,  # Start with frequency prompt
            lines=15
        )

        # Input Sections (conditionally visible)
        with gr.Group(visible=True) as frequency_inputs: # Visible by default
            gr.Markdown("### Frequency Analysis Inputs")
            player_stats_input = gr.Textbox(
                label="Player Move Frequency Stats (Long-Term)", value=DEFAULT_PLAYER_STATS, lines=4,
                info="Overall player move distribution."
            )

        with gr.Group(visible=False) as markov_inputs: # Hidden by default
            gr.Markdown("### Markov Prediction Analysis Inputs")
            gr.Markdown("*Use the System Prompt field to directly input your Markov analysis instructions.*")

        # New behavior analysis inputs
        with gr.Group(visible=False) as behavior_inputs:
            gr.Markdown("### Win/Loss/Tie Behavior Analysis Inputs")
            player_behavior_input = gr.Textbox(
                label="Player's Last Outcome and Move", value=DEFAULT_PLAYER_BEHAVIOR, lines=4,
                info="Enter the last outcome (Win/Loss/Tie) and move (Rock/Paper/Scissors)."
            )

        # General Inputs / Parameters / Outputs
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Generation Parameters")
                max_length_slider = gr.Slider(minimum=50, maximum=1024, value=300, step=16, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
                
                # Add a checkbox to control saving to database
                save_to_db_checkbox = gr.Checkbox(
                    label="Save this test to database", 
                    value=True,
                    info="Store input and output in SQLite database for later reference"
                )

        submit_btn = gr.Button("Generate Response", variant="primary")

        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Performance Metrics")
                time_output = gr.Textbox(label="Generation Time", interactive=False)
                tokens_output = gr.Number(label="Generated Tokens", interactive=False)
            with gr.Column():
                 gr.Markdown("""
                #### Testing Tips
                - Select the desired **Analysis Mode**.
                - Fill in the inputs for the **selected mode only**.
                - **Edit the System Prompt** above as needed for testing.
                - Use low **Temperature** for factual analysis.
                """)

        with gr.Row():
            final_prompt_display = gr.Textbox(
                label="Formatted Input Sent to Model (via Chat Template)", lines=20
            )
            response_display = gr.Textbox(
                label="Model Response", lines=20, show_copy_button=True
            )

    # Add a new tab for test history
    with gr.Tab("Test History"):
        gr.Markdown("### Saved Test Results")
        
        refresh_btn = gr.Button("Refresh History")
        
        # Display test history as a dataframe
        test_history_df = gr.Dataframe(
            headers=["Test ID", "Analysis Mode", "Timestamp", "Generation Time", "Tokens"],
            label="Recent Tests",
            interactive=False
        )
        
        # Add a number input to load a specific test
        test_id_input = gr.Number(
            label="Test ID", 
            precision=0,
            info="Enter a Test ID to load details"
        )
        load_test_btn = gr.Button("Load Test")
        
        # Display test details
        with gr.Group():
            test_mode_display = gr.Textbox(label="Analysis Mode", interactive=False)
            test_prompt_display = gr.Textbox(label="System Prompt", interactive=False, lines=8)
            test_input_display = gr.Textbox(label="Input Content", interactive=False, lines=4)
            test_response_display = gr.Textbox(label="Model Response", interactive=False, lines=8)
            
            with gr.Row():
                test_time_display = gr.Number(label="Generation Time (s)", interactive=False)
                test_tokens_display = gr.Number(label="Tokens Generated", interactive=False)
                test_temp_display = gr.Number(label="Temperature", interactive=False)
                test_topp_display = gr.Number(label="Top P", interactive=False)

    # --- Event Handlers ---

    # Function to update UI visibility AND system prompt content based on mode selection
    def update_ui_visibility_and_prompt(mode):
        if mode == "Frequency Only":
            return {
                frequency_inputs: gr.update(visible=True),
                markov_inputs: gr.update(visible=False),
                behavior_inputs: gr.update(visible=False),
                system_prompt_input: gr.update(value=DEFAULT_SYSTEM_PROMPT_FREQ) # Load Frequency prompt
            }
        elif mode == "Markov Prediction Only":
            return {
                frequency_inputs: gr.update(visible=False),
                markov_inputs: gr.update(visible=True),
                behavior_inputs: gr.update(visible=False),
                system_prompt_input: gr.update(value=DEFAULT_SYSTEM_PROMPT_MARKOV) # Load Markov prompt
            }
        elif mode == "Behavior Analysis":
            return {
                frequency_inputs: gr.update(visible=False),
                markov_inputs: gr.update(visible=False),
                behavior_inputs: gr.update(visible=True),
                system_prompt_input: gr.update(value=DEFAULT_SYSTEM_PROMPT_BEHAVIOR) # Load Behavior prompt
            }
        else:  # Default case
            return {
                frequency_inputs: gr.update(visible=True),
                markov_inputs: gr.update(visible=False),
                behavior_inputs: gr.update(visible=False),
                system_prompt_input: gr.update(value=DEFAULT_SYSTEM_PROMPT_FREQ)
            }

    # Function to update test history display
    def update_test_history():
        if db_initialized:
            history = get_test_history(limit=20)
            return [[h[0], h[1], h[2], h[3], h[4]] for h in history]
        else:
            return [["N/A", "Database Not Available", "N/A", 0, 0]]
    
    # Function to load test details
    def load_test_details(test_id):
        if not db_initialized:
            return ["Database Not Available", "", "", "", 0, 0, 0, 0]
        
        test = get_test_details(test_id)
        if test:
            return [
                test["analysis_mode"],
                test["system_prompt"],
                test["input_content"] or "",
                test["model_response"],
                test["generation_time"],
                test["tokens_generated"],
                test["temperature"],
                test["top_p"]
            ]
        else:
            return ["Test not found", "", "", "", 0, 0, 0, 0]

    # Link the radio button change to the UI update function
    analysis_mode_selector.change(
        fn=update_ui_visibility_and_prompt, # Use the combined update function
        inputs=analysis_mode_selector,
        outputs=[frequency_inputs, markov_inputs, behavior_inputs, system_prompt_input] # Components to update
    )

    # Handle button click - Pass the single visible system prompt
    submit_btn.click(
        process_input,
        inputs=[
            analysis_mode_selector,
            player_stats_input,
            player_behavior_input,
            system_prompt_input, # Pass the visible system prompt textbox
            max_length_slider,
            temperature_slider,
            top_p_slider,
            save_to_db_checkbox  # Pass the checkbox value
        ],
        outputs=[
            final_prompt_display, response_display,
            time_output, tokens_output
        ]
    )
    
    # Connect buttons for test history tab
    refresh_btn.click(
        update_test_history,
        outputs=[test_history_df]
    )
    
    load_test_btn.click(
        load_test_details,
        inputs=[test_id_input],
        outputs=[
            test_mode_display, test_prompt_display, test_input_display, 
            test_response_display, test_time_display, test_tokens_display,
            test_temp_display, test_topp_display
        ]
    )
    
    # Initialize history on page load
    demo.load(update_test_history, outputs=[test_history_df])

# --- Launch the demo ---
if __name__ == "__main__":
    demo.launch()