# Chris-lab/utils/model.py
import logging
import os
import torch
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Gated checkpoints (e.g. Llama, Mistral) need a Hugging Face token; skip login when none is set.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)


class Model(torch.nn.Module):
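    """Wrapper around a small set of Hugging Face checkpoints, served through vLLM
    for decoder-only models and through plain transformers for google-t5/t5-large."""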
number_of_models = 0
__model_list__ = [
"Qwen/Qwen2-1.5B-Instruct",
"lmsys/vicuna-7b-v1.5",
"google-t5/t5-large",
"mistralai/Mistral-7B-Instruct-v0.1",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
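        """Load `model_name`, preferring vLLM and falling back to transformers for T5."""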
        super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.name = model_name
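        # google-t5/t5-large is an encoder-decoder model, so it is served with
        # transformers; every other entry in __model_list__ goes through vLLM.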
self.use_vllm = model_name != "google-t5/t5-large"
logging.info(f"Start loading model {self.name}")
if self.use_vllm:
            # Load the model with vLLM (half precision).
self.llm = LLM(
model=model_name,
dtype="half",
tokenizer=model_name,
trust_remote_code=True,
)
else:
            # Load the seq2seq model with plain transformers (bfloat16, auto device placement).
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
self.model.eval()
logging.info(f"Loaded model {self.name}")
self.update()

    @classmethod
def update(cls):
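        """Count how many Model instances have been constructed."""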
cls.number_of_models += 1

    def gen(self, content_list, temp=0.001, max_length=500, do_sample=True):
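        """Generate one completion per prompt in `content_list`.

        Returns a list of generated strings, in the same order as the prompts.
        """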
if self.use_vllm:
sampling_params = SamplingParams(
temperature=temp,
max_tokens=max_length,
# top_p=0.95 if do_sample else 1.0,
stop_token_ids=[self.tokenizer.eos_token_id],
)
outputs = self.llm.generate(content_list, sampling_params)
return [output.outputs[0].text for output in outputs]
else:
            inputs = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).to(self.model.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=do_sample,
                temperature=temp,
                eos_token_id=self.tokenizer.eos_token_id,
            )
            # Seq2seq generate() returns only decoder tokens (no prompt prefix),
            # so decode the whole output instead of slicing off the input length.
            return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
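        """Stream completions for `content_list` token by token.

        Yields (prompt_index, token_text) pairs until every prompt has finished
        or `max_length` tokens have been produced.
        """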
if self.use_vllm:
sampling_params = SamplingParams(
temperature=temp,
max_tokens=max_length,
top_p=0.95 if do_sample else 1.0,
stop_token_ids=[self.tokenizer.eos_token_id],
)
            # NOTE: vLLM's offline LLM.generate() does not stream; generate the full
            # completions first, then yield their tokens one request at a time.
            outputs = self.llm.generate(content_list, sampling_params)
            for i, request_output in enumerate(outputs):
                for token_id in request_output.outputs[0].token_ids:
                    yield i, self.tokenizer.decode(token_id, skip_special_tokens=True)
else:
            # T5 is an encoder-decoder model: keep the encoder inputs fixed and grow
            # the decoder sequence one token at a time.
            inputs = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).to(self.model.device)
            batch_size = inputs.input_ids.shape[0]
            decoder_input_ids = torch.full(
                (batch_size, 1),
                self.model.config.decoder_start_token_id,
                dtype=torch.long,
                device=self.model.device,
            )
            # Original batch positions of the prompts that are still generating.
            active = torch.arange(batch_size, device=self.model.device)
            generated_tokens = 0
            while generated_tokens < max_length and len(active) > 0:
                with torch.no_grad():
                    output = self.model.generate(
                        input_ids=inputs.input_ids[active],
                        attention_mask=inputs.attention_mask[active],
                        decoder_input_ids=decoder_input_ids,
                        max_new_tokens=1,
                        do_sample=do_sample,
                        temperature=temp,
                    )
                next_tokens = output[:, -1].unsqueeze(-1)
                for i, token in zip(active, next_tokens):
                    yield i.item(), self.tokenizer.decode(
                        token.item(), skip_special_tokens=True
                    )
                decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)
                generated_tokens += 1
                # Drop the rows that just emitted EOS; `keep` indexes the current
                # active batch, while `active` tracks the original prompt positions.
                keep = (
                    (next_tokens.squeeze(-1) != self.tokenizer.eos_token_id)
                    .nonzero()
                    .squeeze(-1)
                )
                active = active[keep]
                decoder_input_ids = decoder_input_ids[keep]