# Inference

In [None]:
import torch

from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, T5Tokenizer

from model.distilbert import DistilBertClassificationModel
from model.scibert import SciBertClassificationModel
from model.llama import LlamaClassificationModel
from model.t5 import T5ClassificationModel

## Model Selection
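Uncomment the desired `repo_id` and the matching `model` class below, and comment out the currently active ones. Make sure the tokenizer and input format used in the following cells also match the selected checkpoint.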

In [2]:
# Choose exactly ONE checkpoint: uncomment its `repo_id` line and comment out
# the currently active one.

# --- Baseline checkpoints (the repo names indicate 25 training epochs) ---
# repo_id = "ppak10/defect-classification-distilbert-baseline-25-epochs"
# repo_id = "ppak10/defect-classification-scibert-baseline-25-epochs"
# repo_id = "ppak10/defect-classification-llama-baseline-25-epochs"
# repo_id = "ppak10/defect-classification-t5-baseline-25-epochs"

# --- Prompt checkpoints (the repo names indicate 2 training epochs) ---
# repo_id = "ppak10/defect-classification-distilbert-prompt-02-epochs"
# repo_id = "ppak10/defect-classification-scibert-prompt-02-epochs"
# repo_id = "ppak10/defect-classification-llama-prompt-02-epochs"
repo_id = "ppak10/defect-classification-t5-prompt-02-epochs"

# Build the wrapper whose model family matches the chosen `repo_id`.
# model = DistilBertClassificationModel(repo_id)
# model = SciBertClassificationModel(repo_id)
# NOTE(review): unlike the other wrappers, this call passes no repo_id —
# confirm LlamaClassificationModel's constructor signature.
# model = LlamaClassificationModel()
model = T5ClassificationModel(repo_id)

In [None]:
# Pick the tokenizer that matches the selected checkpoint: T5 checkpoints use
# T5Tokenizer, while the DistilBERT, SciBERT, and Llama checkpoints load through
# AutoTokenizer. Deriving it from `repo_id` avoids a silent mismatch caused by
# forgetting to toggle a commented-out line.
if "-t5-" in repo_id:
    tokenizer = T5Tokenizer.from_pretrained(repo_id)
else:
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Download the fine-tuned classification head weights from the Hub.
classification_head_path = hf_hub_download(
    repo_id=repo_id,
    repo_type="model",
    filename="classification_head.pt",
)

# The file comes from a remote repository. weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered download cannot run
# arbitrary code. map_location keeps the weights on CPU regardless of where
# they were saved.
classification_head_state = torch.load(
    classification_head_path,
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.classifier.load_state_dict(classification_head_state)

# Switch to evaluation mode for inference.
model.eval()

In [None]:
# Example L-PBF process parameters, written in the input format that matches
# each checkpoint variant:
#   - baseline checkpoints: "[SEP]"-delimited field list
#   - prompt checkpoints:   natural-language question
# The format is chosen from `repo_id` so it cannot drift out of sync with the
# checkpoint selected above. Any repo_id that is not a baseline checkpoint
# falls back to the prompt format.
if "-baseline-" in repo_id:
    text = "Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns"
else:
    text = "What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?"

In [None]:
# Maximum token length. Inputs are truncated or padded to exactly this size.
MAX_LENGTH = 256
# Probability above which a label counts as present.
THRESHOLD = 0.5
# Label names, in the order of the classifier's output units.
classifications = ["None", "Keyhole", "Lack of Fusion", "Balling"]

# Tokenize the input into PyTorch tensors.
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=MAX_LENGTH)

# Drop `token_type_ids`, which the SciBERT tokenizer emits. Presumably the
# wrapper's forward() does not accept it — confirm against model/scibert.
inputs_kwargs = {key: value for key, value in inputs.items() if key != "token_type_ids"}

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    outputs = model(**inputs_kwargs)

# Multi-label classification: apply an independent sigmoid to each logit.
probs = torch.sigmoid(outputs["logits"])

# Threshold into a 0/1 vector and drop the batch dimension of size 1.
preds = (probs > THRESHOLD).int().squeeze()

# Guard against a head whose width does not match the label list.
# zip() below would otherwise silently truncate.
assert preds.numel() == len(classifications), (
    f"expected {len(classifications)} label outputs, got shape {tuple(preds.shape)}"
)

print([label for label, flag in zip(classifications, preds.tolist()) if flag == 1])