File size: 1,113 Bytes
cf63839
 
 
 
 
 
 
 
 
 
 
733f97e
cf63839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
# Redirect the Hugging Face model cache. Must be set BEFORE importing
# transformers, or the default cache location is picked up instead.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME / HF_HUB_CACHE — confirm the installed version.
# NOTE(review): the path is relative (no leading "/"); looks like it may be
# intended as an absolute cluster path — verify.
os.environ['TRANSFORMERS_CACHE'] = "data/parietal/store3/soda/lihu/hf_model/"
from transformers import AutoTokenizer
import transformers
import torch

# Hub id of the chat-tuned TinyLlama checkpoint used for generation.
model = "PY007/TinyLlama-1.1B-Chat-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model)
# Text-generation pipeline; device_map="auto" lets accelerate place the
# weights, float32 keeps inference CPU-safe.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float32,
    device_map="auto",
)
# Token id of <|im_end|> for this checkpoint; used to stop generation at the
# end of the assistant turn.
CHAT_EOS_TOKEN_ID = 32002

def generate_answer(query, sample_num=3):
    """Sample chat-style answers to *query* from the TinyLlama pipeline.

    Args:
        query: User question, inserted into a ChatML-style prompt.
        sample_num: Number of independent completions to sample.

    Returns:
        A list of ``sample_num`` answer strings with the prompt stripped.
    """
    formatted_prompt = (
        f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
    )

    sequences = pipeline(
        formatted_prompt,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        num_return_sequences=sample_num,
        repetition_penalty=1.1,
        max_new_tokens=150,
        eos_token_id=CHAT_EOS_TOKEN_ID,
    )
    # Slice off the leading prompt by length rather than str.replace():
    # replace() would also delete any later occurrence of the prompt text
    # that the model happened to echo inside its answer.
    return [seq['generated_text'][len(formatted_prompt):] for seq in sequences]