import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces
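
# The same HF_TOKEN is used for the gated meta-llama tokenizer below and for
# downloading the model checkpoint in load_model().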
hf_token = os.getenv("HF_TOKEN")
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
# --- Load token probabilities ---
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
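
# token_probabilities must contain one probability per tokenizer vocabulary id
# (the noising functions below index it with np.arange(vocab_size)); it serves
# as the base distribution that replacement noise tokens are sampled from.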
def load_model():
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_bi",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
        revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Create the model skeleton from its config
    model = CustomTransformerModel(CustomTransformerConfig())

    # Step 2: Load the checkpoint, which may be a full model object or a bare state_dict
    full_model = torch.load(ckpt_path, map_location=device)
    try:
        state_dict = full_model.state_dict()
    except AttributeError:
        state_dict = full_model  # already a state_dict

    # Step 3: Load weights non-strictly and report any mismatches
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
rng = np.random.default_rng()
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None
def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
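
# Illustration (values computed from the formula above, with sharpness=5.0):
# the schedule starts at 1.0 and decays smoothly to 0.0, so early iterations
# re-noise most of the answer and later ones barely touch it.
#   get_noising_schedule(0, 10)   # -> 1.0
#   get_noising_schedule(5, 10)   # -> ~0.076
#   get_noising_schedule(10, 10)  # -> 0.0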
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, mask_weight=0.0, clustering=0.5, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    if num_to_noise == 0:
        return noised, []

    mixed_probs = token_probabilities.copy()
    # Apply EOT weighting
    mixed_probs[eot_token_id] *= eot_weight
    # Scale all other probabilities so they sum to 1 - mask_weight
    total_other = mixed_probs.sum() - mixed_probs[mask_token_id]
    scale = (1.0 - mask_weight) / total_other
    mixed_probs *= scale
    # Set mask_token_id to mask_weight explicitly
    mixed_probs[mask_token_id] = mask_weight

    # Pick positions in clusters: as clustering -> 1, fewer but larger spans
    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))
    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))

    noised_indices = sorted(noised_indices)[:num_to_noise]
    noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
    for idx, val in zip(noised_indices, noise):
        noised[idx] = val
    return noised, noised_indices
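
# Illustrative call (hypothetical values): re-noise ~40% of the answer region
# in moderately clustered spans, sampling replacements from the reweighted
# distribution above.
#   noised_ids, idxs = noisify_answer(ids, answer_start, threshold=0.4,
#                                     eot_weight=0.5, mask_weight=0.0,
#                                     clustering=0.5)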

# --- Confidence-guided noising ---
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, eot_weight=1.0, mask_weight=0.0, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    if num_to_noise == 0:
        return noised

    # Low-confidence positions get proportionally higher selection weight
    raw_weights = 1.0 - np.array(confidences[answer_start:])
    # Clip to avoid zero-probability weights during selection:
    # noise_clipping == 1.0 gives every token an equal chance of being noised;
    # values near zero select tokens almost purely by (inverse) prediction confidence.
    raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
    weights = raw_weights / raw_weights.sum()

    num_to_noise = min(num_to_noise, len(weights))  # prevent oversampling
    indices = rng.choice(
        np.arange(answer_start, len(input_ids)),
        size=num_to_noise,
        replace=False,
        p=weights,
    )

    mixed_probs = token_probabilities.copy()
    # Apply EOT weighting
    mixed_probs[eot_token_id] *= eot_weight
    # Scale all other probabilities so they sum to 1 - mask_weight
    total_other = mixed_probs.sum() - mixed_probs[mask_token_id]
    scale = (1.0 - mask_weight) / total_other
    mixed_probs *= scale
    # Set mask_token_id to mask_weight explicitly
    mixed_probs[mask_token_id] = mask_weight

    noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
    for idx, val in zip(indices, noise):
        noised[idx] = val
    return noised
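
# Unlike noisify_answer, which picks positions (semi-)randomly in clustered
# spans, this variant spends the noise budget on the positions whose previous
# predictions had the lowest confidence.
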
@spaces.GPU
def generate_diffusion_text(input_ids):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        logits = logits.clamp(min=-1e4, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
        # Extract confidence of selected tokens
        conf = probs[torch.arange(len(sampled)), sampled].cpu().numpy()
    return sampled, conf
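
# Note: `probs` is a [seq_len, vocab] matrix, so torch.multinomial draws one
# token per position independently; the whole sequence is re-sampled in a
# single forward pass rather than autoregressively.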
# --- Inference Wrapper ---
def diffusion_chat(question, eot_weight, mask_weight, max_it, pause_length, sharpness, clustering, noise_start, use_confidence_noising, noise_clipping):
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    # Pad or truncate to a fixed 256-token context
    if len(input_ids) < 256:
        input_ids += [pad_token] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, threshold=1.0, eot_weight=eot_weight,
        mask_weight=mask_weight, clustering=clustering, noise_start=1.0,
    )
    yield "<b>Iteration 0 (initial noise):</b><br>" + tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).replace('\n', '<br>')
    time.sleep(pause_length)

    last_tokens = []
    prev_decoded_tokens = []
    for i in range(max_it):
        print('Generating output')

        # Model step
        generated_tokens, confidences = generate_diffusion_text(current_tokens)

        # Save full output for the next noising step
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # --- GREEN HIGHLIGHT: tokens that changed since the previous iteration ---
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eot_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
                highlighted.append(f'<span style="color:green">{token_str}</span>')
            else:
                highlighted.append(token_str)
        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

        # --- Early stopping: halt once the answer is identical for three consecutive iterations ---
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break
        # --- NOISING STEP ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping,
                threshold=threshold, eot_weight=eot_weight,
                mask_weight=mask_weight, noise_start=noise_start,
            )
            just_noised_indices = []
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, threshold=threshold,
                eot_weight=eot_weight, mask_weight=mask_weight,
                clustering=clustering, noise_start=noise_start,
            )

        # Compose the full input again: prompt + noised answer
        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

        # --- RED HIGHLIGHT: tokens that were just re-noised ---
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            token_str = tokenizer.convert_tokens_to_string([tok])
            abs_idx = answer_start + j
            if abs_idx in just_noised_indices:
                highlighted.append(f'<span style="color:red">{token_str}</span>')
            else:
                highlighted.append(token_str)
        yield f"<b>Iteration {i+1}/{max_it} (before noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    print(final_output)
    yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
# --- Gradio Interface ---
print("Loading model...")
model = load_model()
print("✅ Model loaded.")
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
        gr.Slider(0, 1, value=0.5, step=0.05, label="EOT weight (↓ = longer answers)"),
        gr.Slider(0, 1, value=0.5, step=0.05, label="MASK weight (↓ = more random answers)"),
        gr.Slider(1, 512, value=32, step=1, label="Max iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations, in seconds (for visualization)"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Sharpness (↓ = more noising)"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering (↑ = fewer, larger noised spans)"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start (↑ = more noise)"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping (↓ = stronger confidence guidance)"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model that generates an answer by iteratively denoising it.",
)
demo.launch(share=True, allowed_paths=["."], ssr_mode=False)
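
# To run locally (assuming this file is saved as app.py and HF_TOKEN grants
# access to the gated meta-llama/Llama-3.2-3B repo):
#   HF_TOKEN=hf_... python app.py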