import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Authenticate with the Hugging Face Hub (gated checkpoints such as Llama 3.1 require it).
login(token=os.getenv("HF_TOKEN"))


class Model(torch.nn.Module):
    number_of_models = 0
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name
        # T5 is an encoder-decoder model; everything else is served through vLLM.
        self.use_vllm = model_name != "google-t5/t5-large"
        logging.info(f"Start loading model {self.name}")
        if self.use_vllm:
            # Load the model with vLLM
            self.llm = LLM(
                model=model_name,
                dtype="half",
                tokenizer=model_name,
                trust_remote_code=True,
            )
        else:
            # Load the plain transformers model
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto"
            )
            self.model.eval()
        logging.info(f"Loaded model {self.name}")
        self.update()

    @classmethod
    def update(cls):
        # Track how many Model instances have been constructed.
        cls.number_of_models += 1

    def gen(self, content_list, temp=0.001, max_length=500, do_sample=True):
        """Generate one completion per prompt in content_list."""
        if self.use_vllm:
            sampling_params = SamplingParams(
                temperature=temp,
                max_tokens=max_length,
                # top_p=0.95 if do_sample else 1.0,
                stop_token_ids=[self.tokenizer.eos_token_id],
            )
            outputs = self.llm.generate(content_list, sampling_params)
            return [output.outputs[0].text for output in outputs]
        else:
            encoded = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).to(self.model.device)
            outputs = self.model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_new_tokens=max_length,
                do_sample=do_sample,
                temperature=temp,
                eos_token_id=self.tokenizer.eos_token_id,
            )
            # Seq2seq generate() returns only decoder-side tokens, so the output
            # does not need to be sliced past the prompt length.
            return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
        """Yield (prompt_index, token_text) pairs as tokens are generated."""
        if self.use_vllm:
            sampling_params = SamplingParams(
                temperature=temp,
                max_tokens=max_length,
                top_p=0.95 if do_sample else 1.0,
                stop_token_ids=[self.tokenizer.eos_token_id],
            )
            # NOTE: the offline LLM.generate() API has no `stream` flag in current
            # vLLM releases; token-level streaming normally goes through
            # AsyncLLMEngine. This branch assumes a vLLM build whose generate()
            # yields partial RequestOutput objects.
            outputs = self.llm.generate(content_list, sampling_params, stream=True)
            prev_token_ids = [[] for _ in content_list]
            for output in outputs:
                for i, request_output in enumerate(output.outputs):
                    current_token_ids = request_output.token_ids
                    new_token_ids = current_token_ids[len(prev_token_ids[i]):]
                    prev_token_ids[i] = list(current_token_ids)
                    for token_id in new_token_ids:
                        token_text = self.tokenizer.decode(
                            token_id, skip_special_tokens=True
                        )
                        yield i, token_text
        else:
            input_ids = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).input_ids.to(self.model.device)
            gen_kwargs = {
                "input_ids": input_ids,
                "do_sample": do_sample,
                "temperature": temp,
                "eos_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": 1,
                "return_dict_in_generate": True,
                "output_scores": True,
            }
            generated_tokens = 0
            batch_size = input_ids.shape[0]
            # active_sequences holds the original prompt indices that are still generating.
            active_sequences = torch.arange(batch_size)
            while generated_tokens < max_length and len(active_sequences) > 0:
                with torch.no_grad():
                    output = self.model.generate(**gen_kwargs)
                next_tokens = output.sequences[:, -1].unsqueeze(-1)
                for i, token in zip(active_sequences, next_tokens):
                    yield i.item(), self.tokenizer.decode(
                        token[0], skip_special_tokens=True
                    )
                # For the encoder-decoder (T5) path, continuation must be fed back
                # through decoder_input_ids; output.sequences already holds the
                # full decoder prefix including the token generated this step.
                gen_kwargs["decoder_input_ids"] = output.sequences
                generated_tokens += 1
                # Drop sequences that just emitted EOS. keep_mask is positional
                # within the current batch, while active_sequences tracks the
                # original prompt indices.
                keep_mask = next_tokens.squeeze(-1) != self.tokenizer.eos_token_id
                active_sequences = active_sequences[keep_mask.cpu()]
                if len(active_sequences) > 0:
                    gen_kwargs["input_ids"] = gen_kwargs["input_ids"][keep_mask]
                    gen_kwargs["decoder_input_ids"] = gen_kwargs["decoder_input_ids"][
                        keep_mask
                    ]
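

# --- Usage sketch (illustrative only) ---
# A minimal example of how the class above might be driven. The prompts are
# made up, and it assumes the default Qwen checkpoint is downloadable and that
# HF_TOKEN is set in the environment; streamed tokens are simply printed.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model = Model("Qwen/Qwen2-1.5B-Instruct")
    prompts = [
        "Explain what a tokenizer does in one sentence.",
        "Translate 'good morning' to French.",
    ]

    # Batch generation: one completion string per prompt.
    for prompt, completion in zip(prompts, model.gen(prompts, max_length=128)):
        print(f"{prompt!r} -> {completion!r}")

    # Streaming: tokens arrive as (prompt_index, token_text) pairs.
    for idx, token in model.streaming(prompts, max_length=64):
        print(f"prompt {idx}: {token}", flush=True)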