Upload 2 files
- app.py +151 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,151 @@
import gradio as gr
from transformers import pipeline
from config import ModelArgs
from inference import remove_prefix
from model import Llama
import torch
from inference_sft import topk_sampling
import os
import subprocess
import re
from tokenizer import Tokenizer
import torch.nn.functional as F
import shutil

# Define model paths (checkpoint hosted at
# https://huggingface.co/YuvrajSingh9886/StoryLlama/blob/main/snapshot_4650.pt)
model_paths = {
    "Pretrained": "weights/pretrained/snapshot_4650.pt"
}

# ACCESS_TOKEN = os.getenv("GDRIVE_ACCESS_TOKEN")

# Download each checkpoint via the helper script before loading it
for i in model_paths:
    subprocess.run(["python", "download_model_weight.py", "--model_type", i.lower()], check=True)

tk = Tokenizer()
tk = tk.ready_tokenizer()


def beam_search(model, prompt, device, max_length=50, beam_width=5, top_k=50, temperature=1.0):
    input_ids = tk.encode(prompt, return_tensors='pt').to(device)

    # Initialize beams with the input repeated beam_width times
    beams = input_ids.repeat(beam_width, 1)
    beam_scores = torch.zeros(beam_width).to(device)  # Cumulative log-probabilities

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(beams)
            logits = outputs[:, -1, :]  # Last-token logits for each beam

        # Apply temperature scaling
        scaled_logits = logits / temperature

        # Calculate log probabilities
        log_probs = F.log_softmax(scaled_logits, dim=-1)

        # Get the top k candidate tokens for each beam
        topk_log_probs, topk_indices = torch.topk(log_probs, top_k, dim=-1)

        # Generate all beam_width * top_k candidate sequences
        expanded_beams = beams.repeat_interleave(top_k, dim=0)
        new_tokens = topk_indices.view(-1, 1)
        candidate_beams = torch.cat([expanded_beams, new_tokens], dim=1)

        # Calculate new scores for all candidates
        expanded_scores = beam_scores.repeat_interleave(top_k)
        candidate_scores = expanded_scores + topk_log_probs.view(-1)

        # Keep the top beam_width candidates
        top_scores, top_indices = candidate_scores.topk(beam_width)
        beams = candidate_beams[top_indices]
        beam_scores = top_scores

    # Select the best beam
    best_idx = beam_scores.argmax()
    best_sequence = beams[best_idx]
    return tk.decode(best_sequence, skip_special_tokens=True)

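# Note (caveat): all beams start as identical copies of the prompt, so the
# first top-k selection scores duplicate candidates identically, which may be
# why the interface below keeps the beam-search option commented out and
# defaults to top-k sampling.
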
# Function to load the selected model
def load_model(model_type):
    model_path = model_paths[model_type]

    # Check if the model exists; if not, download it
    # if not os.path.exists(model_path):
    #     shutil.rmtree(model_path)
    #     os.mkdir(model_path)
    #     print(f"{model_type} Model not found! Downloading...")
    #     subprocess.run(["python", "download_model_weight.py", f"--{model_type.lower()}"], check=True)
    # else:
    #     print(f"{model_type} Model found, skipping download.")

    # Build the model from the configured hyperparameters
    model = Llama(
        device=ModelArgs.device,
        embeddings_dims=ModelArgs.embeddings_dims,
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout
    )
    model = model.to(ModelArgs.device)

    # Load the checkpoint and strip the torch.compile '_orig_mod.' prefix
    # from the state-dict keys before restoring
    dict_model = torch.load(model_path, weights_only=False)
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])
    model.eval()

    return model

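# Note: torch.load with weights_only=False can unpickle arbitrary Python
# objects, so checkpoints should only come from the trusted repository above.
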
# Track the loaded model and its type so we only reload when the selection changes
current_model_type = "Pretrained"
current_model = load_model(current_model_type)


def answer_question(model_type, prompt, temperature, top_k, max_length):
    global current_model, current_model_type
    if model_type == "Base (Pretrained)":
        model_type = "Pretrained"
    # Reload the model only if the selected type differs from the loaded one
    if model_type != current_model_type:
        current_model = load_model(model_type)
        current_model_type = model_type

    # formatted_prompt = f"### Instruction: Answer the following query. \n\n ### Input: {prompt}.\n\n ### Response: "

    with torch.no_grad():
        # if decoding_method == "Beam Search":
        #     generated_text = beam_search(current_model, formatted_prompt, device=ModelArgs.device,
        #                                  max_length=max_length, beam_width=5, top_k=top_k, temperature=temperature)
        # else:
        generated_text = topk_sampling(current_model, prompt, max_length=max_length,
                                       top_k=top_k, temperature=temperature, device=ModelArgs.device)
    return generated_text

iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Dropdown(choices=["Base (Pretrained)"], value="Base (Pretrained)", label="Select Model"),
        # gr.Dropdown(choices=["Top-K Sampling", "Beam Search"], value="Top-K Sampling", label="Decoding Method"),
        gr.Textbox(label="Prompt", lines=5),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=50, maximum=ModelArgs.vocab_size, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=10, maximum=ModelArgs.block_size, value=256, step=1, label="Max Length")
    ],
    outputs=gr.Textbox(label="Answer"),
    title="StoryLlama",
    description="Enter a prompt, select a model (Pretrained), and the model will generate a story!"
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()
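topk_sampling is imported from inference_sft, which is not part of this commit. As a rough sketch of what such a sampler typically looks like, assuming only the call signature used in answer_question above (the body is illustrative, not the repository's implementation):

def topk_sampling_sketch(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    # Encode the prompt and generate tokens autoregressively
    input_ids = tk.encode(prompt, return_tensors='pt').to(device)
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids)[:, -1, :] / temperature
        # Keep the top_k most likely tokens, renormalize, and sample one
        top_values, top_indices = torch.topk(logits, top_k, dim=-1)
        probs = F.softmax(top_values, dim=-1)
        next_token = top_indices.gather(-1, torch.multinomial(probs, num_samples=1))
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return tk.decode(input_ids[0], skip_special_tokens=True)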
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torch
transformers
gdown
huggingface_hub
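The dependencies above are unpinned, so each Space build resolves the latest releases. For reproducible builds, a pinned variant could look like the following; the version numbers are illustrative assumptions, not part of the commit:

torch==2.1.0
transformers==4.38.0
gdown==5.1.0
huggingface_hub==0.21.0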