import os
import torch
import random
import numpy as np
import argparse
import json
import cohere
from openai import OpenAI

from tqdm import tqdm

from collections import Counter

from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import hashlib


OPENAI_TOKEN = ""
COHERE_TOKEN = ""
HF_TOKEN = ""
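
# The tokens above are intentionally left blank. As a hedged alternative, they
# could be read from the environment instead of hardcoded; the variable names
# below are common conventions, not something any of these APIs requires:
#
#     OPENAI_TOKEN = os.environ.get("OPENAI_API_KEY", "")
#     COHERE_TOKEN = os.environ.get("COHERE_API_KEY", "")
#     HF_TOKEN = os.environ.get("HF_TOKEN", "")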


def argmax(array):
    """argmax with deterministic pseudorandom tie breaking."""
    max_indices = np.arange(len(array))[array == np.max(array)]
    # Hash the array bytes so the tie-break is a pure function of the input.
    idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(), 16) % len(max_indices)
    return max_indices[idx]
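
# Example: argmax(np.array([1.0, 3.0, 3.0])) picks between the tied indices 1
# and 2, but the same input always yields the same index across runs.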


def logsumexp(x):
    # Subtract the max before exponentiating to avoid overflow
    # (the standard log-sum-exp trick).
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))


def normalize(x):
    """Turn a vector of log-scores into a probability distribution."""
    x = np.array(x)
    return np.exp(x - logsumexp(x))
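
# Example: normalize([np.log(1), np.log(3)]) -> array([0.25, 0.75]).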


def set_seed(seed):
    """Seed all RNGs used here (Python, NumPy, PyTorch CPU and current GPU)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
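
# Note: torch.cuda.manual_seed seeds only the current CUDA device; multi-GPU
# runs would also call torch.cuda.manual_seed_all(seed), and bitwise
# reproducibility additionally requires torch.backends.cudnn.deterministic = True.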


def get_commandr_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Query a Cohere Command R model via the chat endpoint (greedy, 64 tokens)."""
    response = gen_model.chat(
        model="command-r",
        message=text,
        temperature=0,
        max_tokens=64,
        seed=seed,
        p=1
    )
    return response.text
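
# Example call (assumes a Cohere client, e.g. from load_model("command-r") below):
#     client, _ = load_model("command-r")
#     get_commandr_chat_response(client, "command-r", "Say hi.", seed=42)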


def get_mt0_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short completion from an mt0-style seq2seq model."""
    input_ids = tokenizer.encode(text, return_tensors="pt").to(gen_model.device)

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )

    response = outputs[0]
    return tokenizer.decode(response, skip_special_tokens=True)


def get_gemma_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short chat completion from a Gemma instruction-tuned model."""
    messages = [
        {"role": "user", "content": text},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    # Gemma ends a turn with <end_of_turn>, not Llama 3's <|eot_id|>.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


def get_mistral_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short chat completion from a Mistral instruction-tuned model."""
    messages = [
        {"role": "user", "content": text},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        # Mistral's chat template terminates turns with the ordinary EOS token;
        # it has no Llama-3-style <|eot_id|> token.
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


def get_llama3_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short chat completion from a Llama 3 instruction-tuned model."""
    messages = [
        {"role": "user", "content": text},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    # Llama 3 marks end-of-turn with <|eot_id|> in addition to the EOS token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


def get_openai_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Query an OpenAI chat model (e.g. gpt-3.5-turbo or gpt-4), greedy, 64 tokens."""
    messages = [
        {"role": "user", "content": text}
    ]
    response = gen_model.chat.completions.create(
        model=gen_model_checkpoint,
        messages=messages,
        temperature=0,
        max_tokens=64,
        top_p=1,
        seed=seed
    )
    return response.choices[0].message.content


def load_model(gen_model_checkpoint, load_in_8bit=False):
    """Load a local model and tokenizer, or an API client for hosted models.

    Returns (gen_model, tokenizer); tokenizer is None for API-backed models.
    """
    gen_model = None
    tokenizer = None

    if ("mistralai/Mistral-7B-Instruct-v0.3" in gen_model_checkpoint
            or "meta-llama/Meta-Llama-3-8B-Instruct" in gen_model_checkpoint
            or "google/gemma-1.1-7b-it" in gen_model_checkpoint):
        if load_in_8bit:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN,
                                                             device_map="auto", load_in_8bit=True)
        else:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        # Tokenizers are plain Python objects; device_map/load_in_8bit do not apply to them.
        tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
    elif "CohereForAI/aya-101" in gen_model_checkpoint or "bigscience/mt0" in gen_model_checkpoint:
        if load_in_8bit:
            gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN,
                                                              device_map="auto", load_in_8bit=True)
        else:
            gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
    elif ("facebook/xglm" in gen_model_checkpoint
            or "bigscience/bloomz" in gen_model_checkpoint
            or "aya-23-8B" in gen_model_checkpoint):
        if load_in_8bit:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN,
                                                             device_map="auto", load_in_8bit=True)
        else:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
    elif "gpt-3.5-turbo" in gen_model_checkpoint or "gpt-4" in gen_model_checkpoint:
        gen_model = OpenAI(api_key=OPENAI_TOKEN)
    elif "command-r" in gen_model_checkpoint:
        gen_model = cohere.Client(COHERE_TOKEN)

    return gen_model, tokenizer
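

if __name__ == "__main__":
    # Minimal usage sketch; the checkpoint, prompt, and seed are illustrative,
    # and a valid HF_TOKEN plus sufficient GPU memory are assumed.
    set_seed(42)
    demo_checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
    gen_model, tokenizer = load_model(demo_checkpoint, load_in_8bit=True)
    print(get_llama3_instruct_chat_response(
        gen_model, tokenizer, demo_checkpoint, "Say hello in French.", seed=42))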