import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces
hf_token = os.getenv("HF_TOKEN")
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
# Fall back to EOS only when the tokenizer defines no pad token at all; a plain
# `or` would also throw away a legitimate pad id of 0.
pad_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
# Token ids of the marker that separates the prompt from the answer region.
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
# --- Load token probabilities ---
# The JSON maps stringified token ids ("0", "1", ...) to a probability; the
# resulting array is the sampling distribution used for noise tokens.
with open("token_probabilities.json", encoding="utf-8") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array(
    [token_probs_dict[str(i)] for i in range(len(token_probs_dict))],
    dtype=np.float32,
)
# def load_model():
# ckpt_path = hf_hub_download(
# repo_id="ruurd/tini_bi_m",
# filename="diffusion-model.pth",
# token=os.getenv("HF_TOKEN")
# )
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = torch.load(ckpt_path, map_location=device)
# model = disable_dropout(model)
# model.to(device)
# model.eval()
# return model
def load_model():
    """Download the pinned diffusion checkpoint and return the model ready for inference.

    The weights are loaded into a freshly constructed ``CustomTransformerModel``.
    The checkpoint may hold either a whole pickled model or a bare state_dict;
    both are accepted. Key mismatches are tolerated (``strict=False``) and printed.

    Returns:
        The model with dropout disabled, moved to CUDA when available, in eval mode.
    """
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_bi",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
        # Pinned commit so the app never silently picks up different weights.
        revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Create model from scratch (its parameters start on the CPU).
    model = CustomTransformerModel(CustomTransformerConfig())

    # Step 2: Load the checkpoint on the CPU. Its tensors are only copied into
    # `model`, which is moved to `device` once at the end; mapping them straight
    # to the GPU would just hold a temporary second copy of the weights there.
    # NOTE(review): torch.load unpickles the file, which is acceptable only
    # because it comes from a pinned revision of a known repo. Newer torch
    # releases default to weights_only=True, which rejects fully pickled
    # models -- confirm the checkpoint format against the installed torch.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Accept both a full nn.Module and a plain state_dict.
    try:
        state_dict = checkpoint.state_dict()
    except AttributeError:
        state_dict = checkpoint  # already a state_dict

    # Step 3: Load weights (might print mismatches)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
# Module-wide random generator shared by the noising helpers. It is built on
# the PCG64 bit generator with no explicit seed, so every process starts from
# fresh OS entropy and runs are not reproducible.
rng = np.random.Generator(np.random.PCG64())
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    """Decode ``token_ids`` to a single-line string.

    Special tokens are skipped, and every newline in the decoded text is
    replaced by a single space.
    """
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    single_line = text.replace("\n", " ")
    return single_line
def find_answer_start(input_ids, marker_ids):
    """Return the index just after the last occurrence of ``marker_ids`` in ``input_ids``.

    The answer begins after the final marker. Searching from the end keeps a
    marker that happens to appear inside the user's question (e.g. a question
    that itself contains "Assistant:") from being taken as the answer start,
    which would let the prompt's tail be noised and regenerated.

    Args:
        input_ids: Sequence of token ids to search (list, tuple or array).
        marker_ids: Sequence of token ids that marks the start of the answer.

    Returns:
        The position of the first token after the marker, or ``None`` when the
        marker does not occur.
    """
    # Normalise to lists so slices compare element-wise with `==` regardless
    # of whether callers pass lists, tuples or NumPy arrays.
    haystack = list(input_ids)
    needle = list(marker_ids)
    width = len(needle)
    for start in range(len(haystack) - width, -1, -1):
        if haystack[start:start + width] == needle:
            return start + width
    return None
def get_noising_schedule(i, max_it, sharpness=5.0):
    """Return the fraction of the answer to re-noise at iteration ``i``.

    Progress ``x = i / max_it`` is mapped through the normalised exponential
    decay ``(exp(-s*x) - exp(-s)) / (1 - exp(-s))``, which is 1.0 at ``x = 0``
    and 0.0 at ``x = 1``; larger ``sharpness`` makes it fall off faster.
    As ``sharpness`` tends to 0 the curve tends to the linear schedule
    ``1 - x``, which is returned directly for ``sharpness == 0`` because the
    closed form would be 0/0 (NaN) there.

    Args:
        i: Current iteration, 0-based.
        max_it: Total number of iterations (must be non-zero).
        sharpness: Decay rate ``s`` of the exponential.
    """
    x = i / max_it
    if sharpness == 0:
        return 1.0 - x
    # Same formula rewritten with expm1 for precision at small sharpness:
    # exp(a) - exp(b) == expm1(a) - expm1(b) and 1 - exp(-s) == -expm1(-s).
    denom = -np.expm1(-sharpness)
    return (np.expm1(-sharpness * x) - np.expm1(-sharpness)) / denom
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, mask_weight=0.0, clustering=0.5, noise_start = 1.0):
    """Overwrite part of the answer region with randomly sampled noise tokens.

    Positions before ``answer_start`` (the prompt) are never touched. The
    number of positions to noise is ``int(threshold * answer_len * noise_start)``.
    They are grouped into contiguous spans: ``clustering=0`` gives one position
    per span, and ``clustering=1`` gives a single span. Replacement tokens are
    drawn from ``token_probabilities`` with the EOT probability multiplied by
    ``eot_weight`` and the MASK token given probability exactly ``mask_weight``.

    Args:
        input_ids: Full token sequence (prompt + answer); not modified.
        answer_start: Index of the first answer token.
        threshold: Fraction of the answer to noise (e.g. from the schedule).
        eot_weight: Multiplier on the EOT token's sampling probability.
        mask_weight: Probability mass given to the MASK token, in [0, 1].
        clustering: In [0, 1]; higher values mean fewer, longer noised spans.
        noise_start: Extra multiplier on the noised fraction.

    Returns:
        ``(noised, noised_indices)``: a noised copy of ``input_ids`` as a list
        and the sorted absolute positions that were overwritten.
    """
    noised = list(input_ids)
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    # `<= 0` rather than `== 0`: an answer_start past the end of the sequence
    # yields a negative count, which would otherwise reach
    # rng.integers(low > high) below and raise.
    if num_to_noise <= 0:
        return noised, []

    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    # Work on a float64 copy so the renormalisation below sums to 1 precisely.
    mixed_probs = token_probabilities.astype(np.float64)
    # Apply EOT weighting
    mixed_probs[eot_token_id] *= eot_weight
    # Scale all non-MASK probabilities so they sum to 1 - mask_weight ...
    total_other = mixed_probs.sum() - mixed_probs[mask_token_id]
    mixed_probs *= (1.0 - mask_weight) / total_other
    # ... then give MASK exactly mask_weight, so the total is 1.
    mixed_probs[mask_token_id] = mask_weight

    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))
    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))
    # Overlapping or edge-clipped spans can cover fewer than num_to_noise
    # positions; the slice is only a safety cap.
    noised_indices = sorted(noised_indices)[:num_to_noise]

    # Sample over len(mixed_probs) ids so the support always matches `p`, even
    # if the probability table and the tokenizer vocabulary differ in size.
    noise = rng.choice(len(mixed_probs), size=len(noised_indices), p=mixed_probs)
    for idx, val in zip(noised_indices, noise):
        noised[idx] = int(val)  # keep plain Python ints, like the rest of the list
    return noised, noised_indices
# Noising variant that targets low-confidence answer positions.
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, eot_weight = 1.0, mask_weight = 0.0, noise_start = 1.0):
    """Re-noise answer tokens, preferring positions the model was least sure about.

    Picks ``int(threshold * answer_len * noise_start)`` distinct answer
    positions, each with probability proportional to
    ``max(1 - confidence, noise_clipping)``. The picked positions are
    overwritten with tokens drawn from ``token_probabilities``, with the EOT
    probability multiplied by ``eot_weight`` and the MASK token given
    probability exactly ``mask_weight``.

    Args:
        input_ids: Full token sequence (prompt + answer); not modified.
        answer_start: Index of the first answer token.
        confidences: Per-position confidences aligned with ``input_ids``; the
            caller passes the probabilities returned by
            ``generate_diffusion_text``.
        noise_clipping: Floor on the selection weights. At 1.0 every answer
            position is equally likely. Values near 0 follow the confidences
            closely. Non-positive values are raised to a tiny epsilon so that
            sampling without replacement never runs out of non-zero weights.
        threshold: Fraction of the answer to noise.
        eot_weight: Multiplier on the EOT token's sampling probability.
        mask_weight: Probability mass given to the MASK token, in [0, 1].
        noise_start: Extra multiplier on the noised fraction.

    Returns:
        A noised copy of ``input_ids`` as a list (noised positions are not reported).
    """
    noised = list(input_ids)
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise <= 0:
        return noised

    raw_weights = 1.0 - np.asarray(confidences[answer_start:], dtype=np.float64)
    # Clip so no position has zero probability. The floor must stay positive:
    # rng.choice(replace=False) needs at least `size` non-zero weights.
    raw_weights = np.clip(raw_weights, a_min=max(noise_clipping, 1e-12), a_max=None)
    weights = raw_weights / raw_weights.sum()
    num_to_noise = min(num_to_noise, len(weights))  # prevent oversampling
    indices = rng.choice(
        np.arange(answer_start, len(input_ids)),
        size=num_to_noise,
        replace=False,
        p=weights,
    )

    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    # Work on a float64 copy so the renormalisation below sums to 1 precisely.
    mixed_probs = token_probabilities.astype(np.float64)
    # Apply EOT weighting
    mixed_probs[eot_token_id] *= eot_weight
    # Scale all non-MASK probabilities so they sum to 1 - mask_weight ...
    total_other = mixed_probs.sum() - mixed_probs[mask_token_id]
    mixed_probs *= (1.0 - mask_weight) / total_other
    # ... then give MASK exactly mask_weight, so the total is 1.
    mixed_probs[mask_token_id] = mask_weight

    # Support sized from the table itself so it always matches `p`.
    noise = rng.choice(len(mixed_probs), size=num_to_noise, p=mixed_probs)
    for idx, val in zip(indices, noise):
        noised[int(idx)] = int(val)
    return noised
@spaces.GPU
def generate_diffusion_text(input_ids):
    """Run one forward pass of the global ``model`` and sample one token per position.

    Args:
        input_ids: List of token ids for the full (padded) sequence.

    Returns:
        ``(sampled, conf)``: ``sampled`` is a list with one sampled token id
        per input position. ``conf`` is a NumPy float32 array with the
        (clamped) probability of each sampled token.

    Raises:
        RuntimeError: if the probabilities contain NaN or infinite values.
    """
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        # Bound the logits (including +/-inf) to a finite range before softmax.
        logits = logits.clamp(min=-1e4, max=1e4)
        # Assumes logits are (1, seq_len, vocab) -- [0] drops the batch axis.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        # Floor tiny probabilities; torch.multinomial does not need rows to sum to 1.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        # Explicit check instead of `assert`, which `python -O` strips out.
        # (After the clamp above, probabilities can no longer be negative.)
        if not torch.isfinite(probs).all():
            raise RuntimeError("Non-finite values in probs!")
        sampled_col = torch.multinomial(probs, num_samples=1)  # (seq_len, 1)
        # Confidence = probability of the token actually sampled at each position.
        # .float() lets .numpy() work even if the model runs in reduced precision.
        conf = probs.gather(-1, sampled_col).squeeze(-1).float().cpu().numpy()
        sampled = sampled_col.squeeze(-1).tolist()
    return sampled, conf
# --- Inference Wrapper ---
MAX_SEQ_LEN = 256  # fixed sequence length (prompt + answer) fed to the model


def _to_html(text):
    """Escape ``text`` for the HTML output and turn newlines into ``<br>``."""
    import html
    return html.escape(text).replace("\n", "<br>")


def _tokens_to_html(token_ids, highlighted=frozenset(), color="", skip_eot=False):
    """Render answer token ids as HTML, coloring positions listed in ``highlighted``.

    Args:
        token_ids: Answer token ids to render.
        highlighted: Set of positions (relative to ``token_ids``) to wrap in
            ``<span style="color:{color}">``.
        color: CSS color for highlighted tokens.
        skip_eot: Omit tokens equal to ``eot_token_id``.
    """
    pieces = []
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    for j, (tok_id, tok) in enumerate(zip(token_ids, tokens)):
        if skip_eot and tok_id == eot_token_id:
            continue
        text = _to_html(tokenizer.convert_tokens_to_string([tok]))
        if j in highlighted:
            text = f'<span style="color:{color}">{text}</span>'
        pieces.append(text)
    return "".join(pieces)


def diffusion_chat(question, eot_weight, mask_weight, max_it, pause_length, sharpness, clustering, noise_start, use_confidence_noising, noise_clipping):
    """Stream HTML snapshots of the iterative denoising of an answer to ``question``.

    The answer region starts fully noised. Each iteration runs one model pass.
    It yields the new answer, with tokens that changed since the previous pass
    shown in green. Unless this was the last iteration or the output stopped
    changing, part of the answer is then re-noised: by the schedule in
    ``get_noising_schedule``, or by confidence when ``use_confidence_noising``
    is set. That state is yielded with the freshly noised positions in red.
    The last yield is the final answer.

    Yields:
        HTML strings for a ``gr.HTML`` output.
    """
    max_it = int(max_it)
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return
    # A prompt that fills the whole window leaves no answer region, and the
    # noising helpers cannot sample positions from an empty/negative range.
    if answer_start >= MAX_SEQ_LEN:
        yield "Error: Question is too long to leave room for an answer."
        return

    # Pad (or truncate) to the fixed sequence length.
    input_ids = (input_ids + [pad_token] * MAX_SEQ_LEN)[:MAX_SEQ_LEN]
    prompt_ids = input_ids[:answer_start]

    current_tokens, _ = noisify_answer(
        input_ids, answer_start, threshold=1.0, eot_weight=eot_weight, mask_weight=mask_weight, clustering=clustering, noise_start=1.0,
    )
    initial_text = tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True)
    yield "Iteration 0 (initial noise):<br>" + _to_html(initial_text)
    time.sleep(pause_length)

    last_tokens = []
    prev_answer_ids = None
    iterations_run = 0
    for i in range(max_it):
        iterations_run = i + 1
        print('Generating output')
        # Model step; the prompt is always restored from the original input.
        generated_tokens, confidences = generate_diffusion_text(current_tokens)
        current_tokens = prompt_ids + generated_tokens[answer_start:]

        # --- GREEN HIGHLIGHT: tokens that changed since the previous generation ---
        answer_ids = current_tokens[answer_start:]
        if prev_answer_ids is None:
            changed = set()
        else:
            changed = {j for j, (new, old) in enumerate(zip(answer_ids, prev_answer_ids)) if new != old}
        prev_answer_ids = answer_ids
        yield (f"Iteration {i+1}/{max_it} (after generation):<br>"
               + _tokens_to_html(answer_ids, changed, "green", skip_eot=True))
        time.sleep(pause_length)

        # --- Early stopping: identical output for three consecutive generations ---
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"Stopped early after {i+1} iterations."
            break

        # The last generation is the final answer; noising it would only put
        # random tokens into the final output.
        if i == max_it - 1:
            break

        # --- NOISING STEP ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping, threshold=threshold, eot_weight=eot_weight, mask_weight=mask_weight, noise_start=noise_start
            )
            just_noised_indices = []  # this mode does not report noised positions
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, mask_weight=mask_weight, clustering=clustering, noise_start=noise_start,
            )
        # Compose full input again: prompt + noised answer
        current_tokens = prompt_ids + list(noised_answer[answer_start:])

        # --- RED HIGHLIGHT: positions that were just noised ---
        noised_rel = {idx - answer_start for idx in just_noised_indices}
        yield (f"Iteration {i+1}/{max_it} (after noising):<br>"
               + _tokens_to_html(current_tokens[answer_start:], noised_rel, "red"))
        time.sleep(pause_length)

    final_ids = [tok_id for tok_id in current_tokens[answer_start:] if tok_id != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(final_ids))
    print(final_output)
    yield f"Final Output (after {iterations_run} iterations):<br>" + _to_html(final_output)
# --- Gradio Interface ---
print("Loading model...")
# Loaded at import time on purpose: generate_diffusion_text reads this
# module-level `model`.
model = load_model()
print("✅ Model loaded.")

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
        gr.Slider(0, 1, value=0.5, step=0.05, label="↓ = longer answers (EOT weight)"),
        gr.Slider(0, 1, value=0.5, step=0.05, label="↓ = more random answers (MASK weight)"),
        gr.Slider(1, 512, value=32, step=1, label="↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="↑ = more noise (noise start)"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively."
)

# Launch only when run as a script, so importing this module (tests, tooling,
# `gradio` reload mode, which picks up `demo`) does not start a server.
if __name__ == "__main__":
    # NOTE(review): allowed_paths=["."] lets Gradio serve any file under the
    # working directory; narrow it if only specific assets must be exposed.
    demo.launch(share=True, allowed_paths=["."], ssr_mode=False)