File size: 3,257 Bytes
8bf595d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#%%
import os
import re

#%%
from transformers import AutoTokenizer, AutoModelForCausalLM
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
from agents import run_with_agent


#%%
def _getenv_as_int(attr, default=None):
    ret = os.getenv(attr, default)
    return None if ret is None else int(ret)


# Optional [start, end) bounds read from the environment and used in the
# __main__ section to slice GAME_NAMES / LEVEL_IDS and the session-id range.
# ``None`` leaves that side of the range open; levels default to stopping
# before index 3.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)

# Number of turns passed to run_with_agent (default 3).
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
# TG_ONESHOT: any non-zero integer enables one-shot mode (prepend_example).
ONE_SHOT = bool(_getenv_as_int("TG_ONESHOT", "0"))
# Size (in billions) used to build the "google/gemma-2-{N}b-it" checkpoint id.
# NOTE(review): this was annotated as {3, 9, 27}, but Gemma-2 instruct
# checkpoints appear to be published as 2b/9b/27b -- confirm before using 3.
GEMMA_SIZE = _getenv_as_int("TG_GEMMA_SIZE", "9")


#%%
# Patterns used by gemma_postproc, compiled once at import time.
# A response that *begins* with a ``` fenced block (no MULTILINE flag, so
# ``^`` anchors to the very start of the string). This extraction was
# previously sketched as crossword-only (games "1" and "7") but applies to
# every game.
_GEMMA_FENCED_RE = re.compile(r'^```\n?([^`]*)\n?```')
# The first **bold** span, optionally wrapped in double quotes. This was
# previously sketched as anagram-only (game "6") but applies to every game.
_GEMMA_BOLD_RE = re.compile(r'\*\*\"?([^\"*]*)\"?\*\*')


def gemma_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Extract the answer string from a raw Gemma response.

    Tried in order:
      1. If the response starts with a ``` fenced block, return its contents
         stripped and with every space removed.
      2. Otherwise, if it contains a ``**bold**`` span, return the (stripped)
         text inside it, without surrounding double quotes.
      3. Otherwise return the response unchanged.

    Args:
        response_txt: Raw model output; ``None`` is treated as empty.
        game_name, difficulty_level, *args, **kwargs: Accepted for the
            post-processing callback interface; currently unused.

    Returns:
        The extracted answer, or ``""`` for a ``None``/empty response.
    """
    if not response_txt:
        # Guard first: calling ``.search(None)`` would raise TypeError.
        return ""

    match = _GEMMA_FENCED_RE.search(response_txt)
    if match:
        return match.group(1).strip().replace(" ", "")

    match = _GEMMA_BOLD_RE.search(response_txt)
    if match:
        return match.group(1).strip()

    return response_txt


#%%
def get_gemma_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate one greedy Gemma reply for an alternating conversation.

    Args:
        texts: Conversation so far. Even indices become ``"user"`` turns and
            odd indices become ``"model"`` turns.
        game_name, difficulty_level, turn, *args, **kwargs: Presumably part
            of the callback signature expected by ``run_with_agent``; unused
            here.

    Returns:
        The decoded, stripped text of only the newly generated tokens, with
        special tokens removed.

    Relies on the module-level globals ``gen_model`` and ``tokenizer``. In
    this script they are only created in the ``__main__`` section.
    """
    messages = [
        {"role": ("model" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(gen_model.device)

    # Stop on the tokenizer's EOS or on Gemma's end-of-turn marker. An explicit
    # ``eos_token_id`` replaces the one in the generation config, so the
    # end-of-turn token must be listed here. (``<|eot_id|>`` is a Llama-3
    # token; looking it up in a Gemma vocab yields the unk id.)
    terminators = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
        terminators.append(end_of_turn_id)

    # Greedy decoding. Clearing the checkpoint's sampling temperature
    # presumably avoids transformers' warning about unused sampling params.
    gen_model.generation_config.temperature = None
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=100,
        eos_token_id=terminators,
        do_sample=False,
    )

    # Drop the prompt tokens; keep only what the model generated.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True).strip()


#%%
if __name__ == "__main__":
    # Output path encodes the model size, zero-/one-shot mode and, when set,
    # the start index of the game and level ranges.
    # NOTE(review): end indices (GAME_ED, LVL_ED) and the SID range are not
    # part of the name, so runs differing only in those share one file.
    fp_out = (f"model_outputs/results_gemma-2-{GEMMA_SIZE}b-it"
              f"{'.1s' if ONE_SHOT else '.zs'}"
              f"{'' if GAME_ST is None else f'.{GAME_ST}'}"
              f"{'' if LVL_ST is None else f'.{LVL_ST}'}"
              f".jsonl")
    gen_model_checkpoint = f"google/gemma-2-{GEMMA_SIZE}b-it"

    # This flag only enables ``device_map="auto"`` placement when loading the
    # model. No quantization config is passed.
    auto_device_map = True
    _model_kwargs = {"device_map": "auto"} if auto_device_map else {}

    # ``gen_model`` and ``tokenizer`` are module-level globals read by
    # get_gemma_response.
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **_model_kwargs)
    # ``device_map`` is a model-loading option; the tokenizer does not use it.
    tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint)
    print(f"    > model.dtype: {gen_model.dtype}")

    run_with_agent(
        fp_out,
        get_gemma_response,
        gemma_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        # Restrict to "session_XXXX" ids in [SID_ST or 0, SID_ED or 1000) when
        # either bound is set (and truthy); otherwise pass None.
        sid_indices=(list(map(lambda r: f"session_{r:04}", range(SID_ST or 0, SID_ED or 1000)))
                     if SID_ST or SID_ED else None),
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )