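"""Model loading and generation helpers.

Wraps a handful of HuggingFace checkpoints behind a single ``Model`` class:
decoder-only models are served through vLLM, while the T5 checkpoint falls
back to plain transformers. Both a batch interface (``gen``) and a
token-streaming interface (``streaming``) are provided.
"""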
import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from vllm import LLM, SamplingParams

login(token=os.getenv("HF_TOKEN"))


class Model(torch.nn.Module):
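    """Wrapper around either a vLLM engine or a transformers model.

    ``use_vllm`` is True for every checkpoint except ``google-t5/t5-large``,
    which is loaded as a seq2seq transformers model instead.
    """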
    number_of_models = 0
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name
        self.use_vllm = model_name != "google-t5/t5-large"

        logging.info(f"Start loading model {self.name}")

        if self.use_vllm:
            # Load the model through vLLM
            self.llm = LLM(
                model=model_name,
                dtype="half",
                tokenizer=model_name,
                trust_remote_code=True,
            )
        else:
            # Load the original transformers model (seq2seq fallback for T5)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto"
            )
            self.model.eval()

        logging.info(f"Loaded model {self.name}")
        self.update()

    @classmethod
    def update(cls):
        cls.number_of_models += 1

    def gen(self, content_list, temp=0.001, max_length=500, do_sample=True):
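        """Generate one completion per prompt in ``content_list``.

        Returns a list of decoded strings, one per prompt. ``do_sample`` is
        only honoured on the transformers path; the vLLM path always samples
        with the given ``temp``.
        """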
        if self.use_vllm:
            sampling_params = SamplingParams(
                temperature=temp,
                max_tokens=max_length,
                # top_p=0.95 if do_sample else 1.0,
                stop_token_ids=[self.tokenizer.eos_token_id],
            )
            outputs = self.llm.generate(content_list, sampling_params)
            return [output.outputs[0].text for output in outputs]
        else:
            input_ids = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).input_ids.to(self.model.device)
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=max_length,
                do_sample=do_sample,
                temperature=temp,
                eos_token_id=self.tokenizer.eos_token_id,
            )
            # Seq2seq generate() returns only decoder tokens (the input is not
            # echoed back), so decode the outputs as-is.
            return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
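        """Yield ``(prompt_index, token_text)`` pairs as tokens are produced.

        Note: the vLLM branch assumes ``LLM.generate`` accepts ``stream=True``;
        the standard offline vLLM API may not expose streaming this way, so
        treat that branch as a sketch. The transformers branch re-invokes
        ``generate`` one token at a time until EOS or ``max_length``.
        """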
        if self.use_vllm:
            sampling_params = SamplingParams(
                temperature=temp,
                max_tokens=max_length,
                top_p=0.95 if do_sample else 1.0,
                stop_token_ids=[self.tokenizer.eos_token_id],
            )
            outputs = self.llm.generate(content_list, sampling_params, stream=True)

            prev_token_ids = [[] for _ in content_list]

            for output in outputs:
                for i, request_output in enumerate(output.outputs):
                    current_token_ids = request_output.token_ids
                    new_token_ids = current_token_ids[len(prev_token_ids[i]) :]
                    prev_token_ids[i] = current_token_ids.copy()

                    for token_id in new_token_ids:
                        token_text = self.tokenizer.decode(
                            token_id, skip_special_tokens=True
                        )
                        yield i, token_text
        else:
            input_ids = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).input_ids.to(self.model.device)

            gen_kwargs = {
                "input_ids": input_ids,
                "do_sample": do_sample,
                "temperature": temp,
                "eos_token_id": self.tokenizer.eos_token_id,
                "max_new_tokens": 1,
                "return_dict_in_generate": True,
                "output_scores": True,
            }

            generated_tokens = 0
            batch_size = input_ids.shape[0]
            active_sequences = torch.arange(batch_size)

            while generated_tokens < max_length and len(active_sequences) > 0:
                with torch.no_grad():
                    output = self.model.generate(**gen_kwargs)

                next_tokens = output.sequences[:, -1].unsqueeze(-1)

                for i, token in zip(active_sequences, next_tokens):
                    yield i.item(), self.tokenizer.decode(
                        token[0], skip_special_tokens=True
                    )

                gen_kwargs["input_ids"] = torch.cat(
                    [gen_kwargs["input_ids"], next_tokens], dim=-1
                )
                generated_tokens += 1

                # `next_tokens` is aligned with the current active batch, so
                # track which positions in that batch just emitted EOS and keep
                # the rest; `active_sequences` keeps the original prompt indices.
                finished_positions = (
                    (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id)
                    .nonzero()
                    .squeeze(-1)
                    .tolist()
                )
                keep_positions = [
                    pos
                    for pos in range(len(active_sequences))
                    if pos not in finished_positions
                ]
                active_sequences = active_sequences[keep_positions]
                if len(active_sequences) > 0:
                    gen_kwargs["input_ids"] = gen_kwargs["input_ids"][keep_positions]