YuvrajSingh9886 committed (verified)
Commit 325f8af · 1 Parent(s): 5bb6ad4

Upload 2 files

Files changed (2)
  1. app.py +151 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,151 @@
+ import gradio as gr
+ import os
+ import subprocess
+ import torch
+ import torch.nn.functional as F
+ from config import ModelArgs
+ from inference import remove_prefix
+ from inference_sft import topk_sampling
+ from model import Llama
+ from tokenizer import Tokenizer
+
+ # Local path for each model checkpoint; snapshot_4650.pt is fetched from
+ # https://huggingface.co/YuvrajSingh9886/StoryLlama/blob/main/snapshot_4650.pt
+ model_paths = {
+     "Pretrained": "weights/pretrained/snapshot_4650.pt"
+ }
+
+
+ # Download the checkpoints before loading them
+ def download_models():
+     for model_type in model_paths:
+         subprocess.run(["python", "download_model_weight.py", "--model_type", model_type.lower()], check=True)
+
+ download_models()
+
+ # Prepare the tokenizer once at startup
+ tk = Tokenizer()
+ tk = tk.ready_tokenizer()
+
+ def beam_search(model, prompt, device, max_length=50, beam_width=5, top_k=50, temperature=1.0):
+     input_ids = tk.encode(prompt, return_tensors='pt').to(device)
+
+     # Initialize beams with the initial input repeated beam_width times
+     beams = input_ids.repeat(beam_width, 1)
+     beam_scores = torch.zeros(beam_width).to(device)  # Initialize scores
+
+     for _ in range(max_length):
+         with torch.no_grad():
+             outputs = model(beams)
+             logits = outputs[:, -1, :]  # Get last-token logits
+
+         # Apply temperature scaling
+         scaled_logits = logits / temperature
+
+         # Calculate log probabilities
+         log_probs = F.log_softmax(scaled_logits, dim=-1)
+
+         # Get the top-k candidates for each beam
+         topk_log_probs, topk_indices = torch.topk(log_probs, top_k, dim=-1)
+
+         # Generate all possible candidates
+         expanded_beams = beams.repeat_interleave(top_k, dim=0)
+         new_tokens = topk_indices.view(-1, 1)
+         candidate_beams = torch.cat([expanded_beams, new_tokens], dim=1)
+
+         # Calculate new scores for all candidates
+         expanded_scores = beam_scores.repeat_interleave(top_k)
+         candidate_scores = expanded_scores + topk_log_probs.view(-1)
+
+         # Select the top beam_width candidates
+         top_scores, top_indices = candidate_scores.topk(beam_width)
+         beams = candidate_beams[top_indices]
+         beam_scores = top_scores
+
+     # Select the best beam
+     best_idx = beam_scores.argmax()
+     best_sequence = beams[best_idx]
+     return tk.decode(best_sequence, skip_special_tokens=True)
+
+ # Load the selected model checkpoint
+ def load_model(model_type):
+     model_path = model_paths[model_type]
+
+     model = Llama(
+         device=ModelArgs.device,
+         embeddings_dims=ModelArgs.embeddings_dims,
+         no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
+         block_size=ModelArgs.block_size,
+         vocab_size=ModelArgs.vocab_size,
+         dropout=ModelArgs.dropout
+     )
+     model = model.to(ModelArgs.device)
+
+     # Load the checkpoint and strip the torch.compile '_orig_mod.' prefix
+     # from the state-dict keys before loading it into the model
+     dict_model = torch.load(model_path, weights_only=False)
+     dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
+     model.load_state_dict(dict_model['MODEL_STATE'])
+     model.eval()
+
+     return model
+
+
+ # Track which checkpoint is currently loaded so we only reload on change
+ current_model_type = "Pretrained"
+ current_model = load_model(current_model_type)
+
+
+ def answer_question(model_type, prompt, temperature, top_k, max_length):
+     global current_model, current_model_type
+
+     # The UI label "Base (Pretrained)" maps onto the "Pretrained" checkpoint
+     if model_type == "Base (Pretrained)":
+         model_type = "Pretrained"
+
+     # Reload the model only if a different checkpoint was selected
+     if model_type != current_model_type:
+         current_model = load_model(model_type)
+         current_model_type = model_type
+
+     # A "Beam Search" decoding option (via beam_search above) is stubbed
+     # out for now; generation uses top-k sampling
+     with torch.no_grad():
+         generated_text = topk_sampling(current_model, prompt, max_length=max_length,
+                                        top_k=top_k, temperature=temperature, device=ModelArgs.device)
+     return generated_text
+
+
+ iface = gr.Interface(
+     fn=answer_question,
+     inputs=[
+         gr.Dropdown(choices=["Base (Pretrained)"], value="Base (Pretrained)", label="Select Model"),
+         gr.Textbox(label="Prompt", lines=5),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
+         gr.Slider(minimum=50, maximum=ModelArgs.vocab_size, value=50, step=1, label="Top-k"),
+         gr.Slider(minimum=10, maximum=ModelArgs.block_size, value=256, step=1, label="Max Length")
+     ],
+     outputs=gr.Textbox(label="Answer"),
+     title="StoryLlama",
+     description="Enter a prompt, select a model (Pretrained), and the model will generate a story!"
+ )
+
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     iface.launch()
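
Note: topk_sampling is imported from inference_sft, which is not part of this commit. For context, a minimal sketch of what such a top-k decoder might look like, assuming the call signature used above, the global tk tokenizer from app.py, and a model that returns logits of shape (batch, seq_len, vocab_size):

    import torch
    import torch.nn.functional as F

    def topk_sampling(model, prompt, max_length=256, top_k=50, temperature=1.0, device='cpu'):
        # Encode the prompt, then sample autoregressively one token at a time
        input_ids = tk.encode(prompt, return_tensors='pt').to(device)
        for _ in range(max_length):
            with torch.no_grad():
                logits = model(input_ids)[:, -1, :] / temperature  # last-token logits
            # Restrict to the top_k most likely tokens and sample from them
            topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
            probs = F.softmax(topk_logits, dim=-1)
            next_token = topk_indices.gather(-1, torch.multinomial(probs, num_samples=1))
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return tk.decode(input_ids[0], skip_special_tokens=True)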
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ gdown
+ huggingface_hub
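
Note: requirements.txt omits gradio, which Hugging Face Spaces preinstalls. For a local run (assuming the repo's config.py, model.py, tokenizer.py, inference.py, inference_sft.py, and download_model_weight.py sit alongside app.py), something like:

    pip install -r requirements.txt gradio
    python app.py

then open the URL Gradio prints (http://localhost:7860 by default).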