Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- app.py +16 -0
- llama_generate.py +162 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from llama_generate import run
|
3 |
+
|
4 |
+
|
5 |
+
def greet(query):
|
6 |
+
results = run(query)
|
7 |
+
return results
|
8 |
+
|
9 |
+
|
10 |
+
sample_list = [
|
11 |
+
"Who is Lihu Chen?.",
|
12 |
+
"Who is Gaël Varoquaux?"
|
13 |
+
]
|
14 |
+
|
15 |
+
iface = gr.Interface(fn=greet, inputs="text", outputs="text", examples=sample_list, cache_examples=False)
|
16 |
+
iface.launch()
|
llama_generate.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
import torch
|
3 |
+
from nltk.tokenize import sent_tokenize
|
4 |
+
|
5 |
+
torch.device('cuda' if torch.cuda.is_available() else 'cpu') # the device to load the model onto
|
6 |
+
model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"
|
7 |
+
|
8 |
+
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
|
9 |
+
#torch_dtype=torch.float16,
|
10 |
+
device_map="auto",
|
11 |
+
trust_remote_code=False,
|
12 |
+
revision="main")
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
17 |
+
tokenizer.pad_token = tokenizer.unk_token
|
18 |
+
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
19 |
+
# model.resize_token_embeddings(len(tokenizer))
|
20 |
+
|
21 |
+
def clean(result):
|
22 |
+
special_token = ['<s>', '</s>', '<unk>']
|
23 |
+
|
24 |
+
result = result.split("[/INST]")[-1].strip()
|
25 |
+
|
26 |
+
# context = "[INST] {a} [/INST]".format(a=content)
|
27 |
+
#result = result.replace(context, '')
|
28 |
+
for token in special_token:
|
29 |
+
result = result.replace(token, '').strip()
|
30 |
+
return result.strip()
|
31 |
+
|
32 |
+
def single_generate(query):
|
33 |
+
|
34 |
+
messages = [
|
35 |
+
{"role": "user", "content": query},
|
36 |
+
]
|
37 |
+
|
38 |
+
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
|
39 |
+
#print(encodeds)
|
40 |
+
|
41 |
+
model_inputs = encodeds.to(device)
|
42 |
+
model.to(device)
|
43 |
+
|
44 |
+
generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True, temperature=1.0)
|
45 |
+
decoded = tokenizer.batch_decode(generated_ids)
|
46 |
+
results = list()
|
47 |
+
for index, result in enumerate(decoded):
|
48 |
+
#print(result)
|
49 |
+
result = clean(result)
|
50 |
+
#print('query = ', query, ' result = ', result)
|
51 |
+
results.append(result)
|
52 |
+
return results
|
53 |
+
|
54 |
+
|
55 |
+
def prepare_input(contents):
|
56 |
+
temp = list()
|
57 |
+
for content in contents:
|
58 |
+
messages = [
|
59 |
+
{"role": "user", "content": content}
|
60 |
+
]
|
61 |
+
#print('messages = ', messages)
|
62 |
+
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", max_length=100, padding=True)
|
63 |
+
# print(encodeds.size())
|
64 |
+
# print(encodeds)
|
65 |
+
temp.append(encodeds[0])
|
66 |
+
|
67 |
+
batch_encoded = torch.stack(temp)
|
68 |
+
return batch_encoded
|
69 |
+
|
70 |
+
def batch_generate(queries):
|
71 |
+
model_inputs = prepare_input(queries).to(device)
|
72 |
+
model.to(device)
|
73 |
+
|
74 |
+
generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True, temperature=1.0)
|
75 |
+
decoded = tokenizer.batch_decode(generated_ids)
|
76 |
+
results = list()
|
77 |
+
for index, result in enumerate(decoded):
|
78 |
+
query = queries[index]
|
79 |
+
result = clean(result)
|
80 |
+
#print('query = ', query, ' result = ', result)
|
81 |
+
results.append(result)
|
82 |
+
return results
|
83 |
+
|
84 |
+
|
85 |
+
def get_yes_or_no(result):
|
86 |
+
if 'yes' in str.lower(result)[:5]:return 'Yes'
|
87 |
+
if 'no' in str.lower(result)[:5]:return 'No'
|
88 |
+
return 'N/A'
|
89 |
+
|
90 |
+
|
91 |
+
def check_score(context, sentences):
|
92 |
+
score_mapping = {'Yes':1.0, 'No':0.0}
|
93 |
+
template = """
|
94 |
+
Context: {a}
|
95 |
+
Sentence: {b}
|
96 |
+
Is the sentence supported by the context above?
|
97 |
+
Answer Yes or No (Don't give explanations):
|
98 |
+
"""
|
99 |
+
scores, results = list(), list()
|
100 |
+
for sentence in sentences:
|
101 |
+
content = template.format(a=context.strip().replace('/n', ''), b=sentence.strip().replace('/n', ''))
|
102 |
+
result = single_generate(content)[0]
|
103 |
+
#result = clean(result, context)
|
104 |
+
#print('results', results)
|
105 |
+
results.append(result)
|
106 |
+
|
107 |
+
results = [get_yes_or_no(r) for r in results]
|
108 |
+
scores = [score_mapping.get(result, 0.5) for result in results]
|
109 |
+
|
110 |
+
# for sent, score in zip(sentences, scores):
|
111 |
+
# print(sent.strip(), score)
|
112 |
+
#result_string += sent + ' ({a})'.format(a=score)
|
113 |
+
|
114 |
+
return scores
|
115 |
+
|
116 |
+
|
117 |
+
def sample_answer(query, num):
|
118 |
+
answers = list()
|
119 |
+
for _ in range(num):
|
120 |
+
answer = single_generate(query)
|
121 |
+
answers.append(answer[0])
|
122 |
+
return answers
|
123 |
+
|
124 |
+
|
125 |
+
def run(query, sample_size=5):
|
126 |
+
sampled = sample_answer(query, sample_size+1)
|
127 |
+
answer = sampled[0]
|
128 |
+
proofs = sampled[1:]
|
129 |
+
sentences = sent_tokenize(answer)
|
130 |
+
|
131 |
+
all_scores = list()
|
132 |
+
for proof in proofs:
|
133 |
+
scores = check_score(proof, sentences)
|
134 |
+
all_scores.append(scores)
|
135 |
+
|
136 |
+
final_content = ''
|
137 |
+
avg_confidence = list()
|
138 |
+
for index, scores in enumerate(zip(*all_scores)):
|
139 |
+
sentence_confidence = sum(scores) / len(scores)
|
140 |
+
avg_confidence.append(sentence_confidence)
|
141 |
+
final_content += sentences[index].strip() + ' ({a}) '.format(a=sentence_confidence)
|
142 |
+
avg_confidence = sum(avg_confidence) / len(avg_confidence)
|
143 |
+
final_content += '\nThe confidence score of this answer is {a}'.format(a=avg_confidence)
|
144 |
+
return final_content
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
# result = sample_answer(query="Who is Lihu Chen?", num=5)
|
148 |
+
# print(result)
|
149 |
+
#batch_generate(["Who is Lihu Chen?", "Who is Lihu Chen?"])
|
150 |
+
|
151 |
+
# context = """
|
152 |
+
# Lihu Chen is an American writer and artist who works in comics. They received their degree in psychology from California State University, Fullerton and have worked on titles such as "The Gathering Storm" and "Heartthrob".
|
153 |
+
# """
|
154 |
+
# sentences = sent_tokenize("""
|
155 |
+
# Lihu Chen is an American writer and artist who works in comics. They received their degree in psychology from California State University, Fullerton and have worked on titles such as "The Gathering Storm" and "Heartthrob".
|
156 |
+
# """)
|
157 |
+
# result = check_score(context, sentences)
|
158 |
+
# print(result)
|
159 |
+
# result = """
|
160 |
+
|
161 |
+
answer = run(query='WHo is Lihu Chen?', sample_size=10)
|
162 |
+
print(answer)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
nltk
|
2 |
+
torch==2.1.0
|
3 |
+
transformers==4.35.0.dev0
|