{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from huggingface_hub import hf_hub_download\n",
"from transformers import AutoTokenizer, T5Tokenizer\n",
"\n",
"from model.distilbert import DistilBertClassificationModel\n",
"from model.scibert import SciBertClassificationModel\n",
"from model.llama import LlamaClassificationModel\n",
"from model.t5 import T5ClassificationModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Selection\n",
"Uncomment desired `repo_id` and corresponding `model` and input type."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Baseline\n",
"# repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
"\n",
"# Prompt \n",
"# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
"# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
"repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
"\n",
"# Initialize the model\n",
"# model = DistilBertClassificationModel(repo_id)\n",
"# model = SciBertClassificationModel(repo_id)\n",
"# model = LlamaClassificationModel()\n",
"model = T5ClassificationModel(repo_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment for DistilBERT, SciBERT, and Llama\n",
"# tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
"\n",
"# Uncomment for T5\n",
"tokenizer = T5Tokenizer.from_pretrained(repo_id)\n",
"\n",
"# Loads classification head weights\n",
"classification_head_path = hf_hub_download(\n",
" repo_id=repo_id,\n",
" repo_type=\"model\",\n",
" filename=\"classification_head.pt\"\n",
")\n",
"\n",
"model.classifier.load_state_dict(torch.load(classification_head_path, map_location=torch.device(\"cpu\")))\n",
"model.eval()"
]
},
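  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check: the next cell only counts the parameters in the loaded classification head to confirm the weights were attached; `model.classifier` is the same attribute used above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: confirm the classification head weights are attached by counting parameters.\n",
    "num_params = sum(p.numel() for p in model.classifier.parameters())\n",
    "print(f\"Classification head parameters: {num_params:,}\")"
   ]
  },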
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Baseline\n",
"# text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
"\n",
"# Prompt\n",
"text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
]
},
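  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When sweeping over process parameters it can help to build prompts programmatically. The `build_prompt` helper below is a hypothetical convenience function that simply mirrors the prompt template above; it is not part of the released models or repositories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper that reproduces the prompt template above.\n",
    "def build_prompt(material, power_w, beam_um, speed_mm_s, hatch_um, layer_um):\n",
    "    return (\n",
    "        f\"What are the likely imperfections that occur in {material} L-PBF builds at {power_w} W, \"\n",
    "        f\"given a {beam_um} microns beam diameter, a {speed_mm_s} mm/s scan speed, \"\n",
    "        f\"a {hatch_um} microns hatch spacing, and a {layer_um} microns layer height?\"\n",
    "    )\n",
    "\n",
    "# Equivalent to the hard-coded prompt above:\n",
    "# text = build_prompt(\"Ti-6Al-4V\", 280.0, 100.0, 400.0, 100.0, 50.0)"
   ]
  },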
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tokenize inputs \n",
"inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
"\n",
"# For SciBERT specific case. \n",
"inputs_kwargs = {}\n",
"for key, value in inputs.items():\n",
" if key not in [\"token_type_ids\"]:\n",
" inputs_kwargs[key] = value\n",
"\n",
"# Perform inference\n",
"outputs = model(**inputs_kwargs)\n",
"\n",
"# Extract logits and apply sigmoid activation for multi-label classification\n",
"probs = torch.sigmoid(outputs[\"logits\"])\n",
"\n",
"# Convert probabilities to one-hot encoded labels\n",
"preds = (probs > 0.5).int().squeeze()\n",
"\n",
"# One hot encoded classifications\n",
"classifications = [\"None\", \"Keyhole\", \"Lack of Fusion\", \"Balling\"]\n",
" \n",
"print([classifications[index] for index, encoding in enumerate(preds) if encoding == 1])"
]
},
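  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since this is multi-label classification, several defect types can be predicted at once. The optional cell below prints the raw sigmoid probability for each class next to its thresholded prediction; it assumes `probs` has shape `(1, num_labels)` with labels ordered as in `classifications`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: show per-class probabilities alongside the 0.5-threshold predictions.\n",
    "# Assumes probs has shape (1, num_labels) ordered as in `classifications`.\n",
    "for label, prob, pred in zip(classifications, probs.squeeze().tolist(), preds.tolist()):\n",
    "    print(f\"{label}: {prob:.3f} ({'predicted' if pred else 'not predicted'})\")"
   ]
  },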
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}