# textgames/agents/gemma_2_9b_it.py
# Hugging Face Hub page header (not part of the program): uploader "fhudi",
# commit message "Upload folder using huggingface_hub", commit 8bf595d (verified).
#%%
import os
import re
#%%
from transformers import AutoTokenizer, AutoModelForCausalLM
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
from agents import run_with_agent
#%%
def _getenv_as_int(attr, default=None):
    """Read environment variable *attr* and return it as an ``int``.

    Args:
        attr: Name of the environment variable to read.
        default: Fallback used when the variable is unset or blank. May be
            an int, a numeric string, or ``None``.

    Returns:
        The parsed integer, or ``None`` when neither the variable nor the
        default supplies a value.

    Raises:
        ValueError: If the value cannot be parsed as an integer. The message
            names the offending variable.
    """
    raw = os.getenv(attr)
    if raw is None or not raw.strip():
        # `export VAR=` yields an empty string; treat it like an unset
        # variable instead of crashing inside int('').
        raw = default
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"environment variable {attr!r} must be an integer, got {raw!r}"
        ) from err
# Run configuration, read once from TG_* environment variables.
# The *_ST / *_ED pairs are start / end bounds; None leaves that side open.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
# Number of turns handed to run_with_agent as n_turns.
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
# "0" / "1" flag; when set, run_with_agent is asked to prepend an example.
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
# Model size in billions, interpolated into "google/gemma-2-{size}b-it".
# NOTE(review): the original note listed {3, 9, 27}; Gemma 2 instruction-tuned
# checkpoints appear to be published as 2b / 9b / 27b -- confirm.
GEMMA_SIZE = int(os.getenv("TG_GEMMA_SIZE", "9"))
#%%
# Answer wrapped in a triple-backtick fence at the very start of the response.
_FENCED_ANSWER_RE = re.compile(r'^```\n?([^`]*)\n?```')
# Answer emphasised in bold, optionally double-quoted: **answer** / **"answer"**.
_BOLD_ANSWER_RE = re.compile(r'\*\*\"?([^\"*]*)\"?\*\*')


def gemma_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Extract the answer from a raw model response.

    Two patterns are tried in order, for every game (the original per-game
    dispatch for crossword / anagram is commented out upstream):

    1. A fenced block opening the response: its content is stripped and all
       spaces are removed.
    2. The first bold span, optionally double-quoted: its content is stripped.

    Args:
        response_txt: Raw text produced by the model; may be ``None`` or empty.
        game_name: Unused; kept for the post-processor call signature.
        difficulty_level: Unused; kept for the post-processor call signature.

    Returns:
        The extracted answer, the unmodified response when no pattern
        matches, or ``""`` when the response is ``None`` or empty.
    """
    if not response_txt:
        # Guard first: the regex search below would raise TypeError on None.
        return ""

    fenced = _FENCED_ANSWER_RE.search(response_txt)
    if fenced:
        return fenced.group(1).strip().replace(" ", "")

    bold = _BOLD_ANSWER_RE.search(response_txt)
    if bold:
        return bold.group(1).strip()

    return response_txt
#%%
def get_gemma_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate the model's next reply for a conversation.

    Relies on the module-level ``tokenizer`` and ``gen_model`` that the
    ``__main__`` section creates before this function is called.

    Args:
        texts: Conversation so far as a list of strings. Even positions are
            sent with role "user", odd positions with role "model".
        game_name: Unused; kept for the agent callback signature.
        difficulty_level: Unused; kept for the agent callback signature.
        turn: Unused; kept for the agent callback signature.

    Returns:
        The newly generated text (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    # global gen_model, tokenizer
    messages = [
        {"role": ("model" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(gen_model.device)

    # Gemma closes a turn with "<end_of_turn>". The previous "<|eot_id|>" is
    # the Llama-3 marker and is not expected in Gemma's vocabulary, so looking
    # it up yields the unknown-token id (or None) rather than a terminator.
    terminators = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if (
        end_of_turn_id is not None
        and end_of_turn_id != tokenizer.unk_token_id
        and end_of_turn_id not in terminators
    ):
        terminators.append(end_of_turn_id)

    # Greedy decoding is used below. Clearing the temperature presumably
    # silences the "sampling parameter set while do_sample=False" warning.
    gen_model.generation_config.temperature = None
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=100,
        eos_token_id=terminators,
        do_sample=False,
    )

    # Keep only the tokens generated after the prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True).strip()
#%%
if __name__ == "__main__":
    # Output file name encodes the model size, the prompting mode (one-shot
    # vs. zero-shot) and, when set, the starting game / level index.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (
        f"model_outputs/results_gemma-2-{GEMMA_SIZE}b-it"
        f"{shot_tag}{game_tag}{level_tag}"
        f".jsonl"
    )

    gen_model_checkpoint = f"google/gemma-2-{GEMMA_SIZE}b-it"

    # NOTE(review): despite its name, this flag passes no quantization
    # settings; it only toggles device_map="auto" on the loaders below.
    quantize = True
    if quantize:
        _kwargs = {"device_map": "auto"}
    else:
        _kwargs = {}

    # Module-level names: get_gemma_response reads gen_model and tokenizer.
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **_kwargs)
    print(f" > model.dtype: {gen_model.dtype}")

    # Restrict to explicit session ids only when a bound is given (and
    # non-zero); otherwise pass None to run_with_agent.
    if SID_ST or SID_ED:
        session_ids = [
            f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)
        ]
    else:
        session_ids = None

    run_with_agent(
        fp_out,
        get_gemma_response,
        gemma_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=session_ids,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )