File size: 4,654 Bytes
1b74e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58558c8
1b74e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58558c8
1b74e0a
 
 
 
 
 
 
 
58558c8
1b74e0a
 
68fc8ba
1b74e0a
 
68fc8ba
1b74e0a
 
 
 
 
 
 
 
58558c8
 
 
 
 
1b74e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58558c8
1b74e0a
 
58558c8
1b74e0a
 
 
 
 
 
 
 
 
 
 
58558c8
1b74e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from huggingface_hub import hf_hub_download\n",
    "from transformers import AutoTokenizer, T5Tokenizer\n",
    "\n",
    "from model.distilbert import DistilBertClassificationModel\n",
    "from model.scibert import SciBertClassificationModel\n",
    "from model.llama import LlamaClassificationModel\n",
    "from model.t5 import T5ClassificationModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Selection\n",
    "Uncomment the desired `repo_id` and the corresponding `model`, tokenizer, and input `text` (baseline or prompt format) in the cells below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline checkpoints (Hub model repositories)\n",
    "# repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
    "\n",
    "# Prompt checkpoints (Hub model repositories)\n",
    "# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
    "# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
    "repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
    "\n",
    "# Build the model wrapper that matches the repo_id selected above\n",
    "# model = DistilBertClassificationModel(repo_id)\n",
    "# model = SciBertClassificationModel(repo_id)\n",
    "# model = LlamaClassificationModel()\n",
    "model = T5ClassificationModel(repo_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment for DistilBERT, SciBERT, and Llama\n",
    "# tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
    "\n",
    "# Uncomment for T5\n",
    "tokenizer = T5Tokenizer.from_pretrained(repo_id)\n",
    "\n",
    "# Download the classification head weights from the selected Hub repository\n",
    "classification_head_path = hf_hub_download(\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"model\",\n",
    "    filename=\"classification_head.pt\",\n",
    ")\n",
    "\n",
    "# torch.load unpickles the file; weights_only=True restricts unpickling to tensors and\n",
    "# plain containers so a downloaded checkpoint cannot run arbitrary code on load.\n",
    "classification_head_state = torch.load(\n",
    "    classification_head_path,\n",
    "    map_location=torch.device(\"cpu\"),\n",
    "    weights_only=True,\n",
    ")\n",
    "model.classifier.load_state_dict(classification_head_state)\n",
    "\n",
    "# Switch the model to evaluation mode before running inference\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline input format\n",
    "# text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
    "\n",
    "# Prompt input format (adjacent literals are concatenated into one string)\n",
    "text = (\n",
    "    \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, \"\n",
    "    \"given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, \"\n",
    "    \"a 100.0 microns hatch spacing, and a 50.0 microns layer height?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the input text, padding/truncating it to a fixed length of 256 tokens\n",
    "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
    "\n",
    "# Drop `token_type_ids` before the forward pass (original note: needed for the SciBERT case)\n",
    "inputs_kwargs = {key: value for key, value in inputs.items() if key != \"token_type_ids\"}\n",
    "\n",
    "# Perform inference; no_grad skips autograd bookkeeping that is not needed here\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs_kwargs)\n",
    "\n",
    "# Extract logits and apply sigmoid activation for multi-label classification\n",
    "probs = torch.sigmoid(outputs[\"logits\"])\n",
    "\n",
    "# Threshold each probability at 0.5 to get a multi-hot vector (several labels may be active)\n",
    "preds = (probs > 0.5).int().squeeze()\n",
    "\n",
    "# Label names, indexed by position in the prediction vector\n",
    "classifications = [\"None\", \"Keyhole\", \"Lack of Fusion\", \"Balling\"]\n",
    "\n",
    "print([classifications[index] for index, encoding in enumerate(preds) if encoding == 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}