{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from huggingface_hub import hf_hub_download\n",
    "from transformers import AutoTokenizer, T5Tokenizer\n",
    "\n",
    "from model.distilbert import DistilBertClassificationModel\n",
    "from model.scibert import SciBertClassificationModel\n",
    "from model.llama import LlamaClassificationModel\n",
    "from model.t5 import T5ClassificationModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Selection\n",
    "Uncomment the desired `repo_id` together with its matching `model` class, tokenizer, and input text format (baseline or prompt)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline\n",
    "# repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
    "\n",
    "# Prompt\n",
    "# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
    "repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
    "\n",
    "# Initialize the model\n",
    "# model = DistilBertClassificationModel(repo_id)\n",
    "# model = SciBertClassificationModel(repo_id)\n",
    "# model = LlamaClassificationModel()\n",
    "model = T5ClassificationModel(repo_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment for DistilBERT, SciBERT, and Llama\n",
    "# tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
    "\n",
    "# Uncomment for T5\n",
    "tokenizer = T5Tokenizer.from_pretrained(repo_id)\n",
    "\n",
    "# Download the classification head weights stored in the same Hub repository\n",
    "classification_head_path = hf_hub_download(\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"model\",\n",
    "    filename=\"classification_head.pt\"\n",
    ")\n",
    "\n",
    "# weights_only=True restricts unpickling to tensors and plain containers, so a\n",
    "# downloaded checkpoint cannot execute arbitrary code while it is being loaded.\n",
    "classification_head_state = torch.load(\n",
    "    classification_head_path,\n",
    "    map_location=torch.device(\"cpu\"),\n",
    "    weights_only=True,\n",
    ")\n",
    "model.classifier.load_state_dict(classification_head_state)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline\n",
    "# text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
    "\n",
    "# Prompt\n",
    "text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize inputs\n",
    "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
    "\n",
    "# Drop `token_type_ids` before calling the model (needed for the SciBERT case).\n",
    "# When the tokenizer does not emit that key, this filter is a no-op.\n",
    "inputs_kwargs = {key: value for key, value in inputs.items() if key != \"token_type_ids\"}\n",
    "\n",
    "# Perform inference; gradients are not needed, so skip building the autograd graph\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs_kwargs)\n",
    "\n",
    "# Extract logits and apply sigmoid activation for multi-label classification\n",
    "probs = torch.sigmoid(outputs[\"logits\"])\n",
    "\n",
    "# Threshold each label independently to obtain a multi-hot prediction vector\n",
    "preds = (probs > 0.5).int().squeeze()\n",
    "\n",
    "# Label names, looked up by position in the prediction vector\n",
    "classifications = [\"None\", \"Keyhole\", \"Lack of Fusion\", \"Balling\"]\n",
    "\n",
    "print([classifications[index] for index, encoding in enumerate(preds) if encoding == 1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}