Updates inference notebook.
model/notebooks/inference.ipynb
CHANGED
@@ -16,7 +16,7 @@
     "import torch\n",
     "\n",
     "from huggingface_hub import hf_hub_download\n",
-    "from transformers import AutoTokenizer\n",
+    "from transformers import AutoTokenizer, T5Tokenizer\n",
     "\n",
     "from model.distilbert import DistilBertClassificationModel\n",
     "from model.scibert import SciBertClassificationModel\n",
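Note: the import hunk adds T5Tokenizer alongside AutoTokenizer, since the T5 checkpoint below is loaded through the SentencePiece-based T5 tokenizer class rather than auto-resolution. A minimal sketch of picking the loader from the checkpoint name; load_tokenizer is a hypothetical helper, not part of the notebook, and the "-t5-" check assumes the repo naming convention seen in this diff:

    from transformers import AutoTokenizer, T5Tokenizer

    def load_tokenizer(repo_id: str):
        # Assumed convention: T5 checkpoints carry "-t5-" in the repo id
        # and use the SentencePiece tokenizer; everything else resolves
        # via AutoTokenizer.
        if "-t5-" in repo_id:
            return T5Tokenizer.from_pretrained(repo_id)
        return AutoTokenizer.from_pretrained(repo_id)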
@@ -39,7 +39,7 @@
     "outputs": [],
     "source": [
     "# Baseline\n",
-    "repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
+    "# repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
     "# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
     "# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
     "# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
@@ -48,7 +48,7 @@
     "# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
     "# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
     "# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
-    "# repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
+    "repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
     "\n",
     "# Initialize the model\n",
     "model = DistilBertClassificationModel(repo_id)\n",
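Note: repo_id now points at the T5 prompt checkpoint, yet the unchanged context line still constructs DistilBertClassificationModel(repo_id), so the model class has to be swapped by hand when the checkpoint changes. If the repo name is meant to drive the choice, a lookup table is one way to keep them in sync; this sketch registers only the two wrappers the diff shows imported, and load_model is hypothetical:

    MODEL_CLASSES = {
        "distilbert": DistilBertClassificationModel,
        "scibert": SciBertClassificationModel,
        # llama/t5 wrappers would be registered here if the project has them
    }

    def load_model(repo_id: str):
        # Pick the wrapper whose family name appears in the repo id.
        for name, cls in MODEL_CLASSES.items():
            if f"-{name}-" in repo_id:
                return cls(repo_id)
        raise ValueError(f"no model class registered for {repo_id}")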
@@ -63,8 +63,11 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "#
-    "tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
+    "# Uncomment for DistilBERT, SciBERT, and Llama\n",
+    "# tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
+    "\n",
+    "# Uncomment for T5\n",
+    "tokenizer = T5Tokenizer.from_pretrained(repo_id)\n",
     "\n",
     "# Loads classification head weights\n",
     "classification_head_path = hf_hub_download(\n",
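Note: this hunk turns the cell into a pair of uncomment-one blocks, with T5Tokenizer active; loading it requires the sentencepiece package. The classification-head download is cut off at hf_hub_download( in the context shown, so the completion below is a sketch only: the filename and the attribute that receives the weights are assumptions, not visible in the diff:

    classification_head_path = hf_hub_download(
        repo_id=repo_id,
        filename="classification_head.pt",  # assumed artifact name
    )
    # map_location="cpu" keeps the load device-agnostic
    state_dict = torch.load(classification_head_path, map_location="cpu")
    model.classification_head.load_state_dict(state_dict)  # assumed attribute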
@@ -84,10 +87,10 @@
     "outputs": [],
     "source": [
     "# Baseline\n",
-    "text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
+    "# text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
     "\n",
     "# Prompt\n",
-    "# text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
+    "text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
     ]
     },
     {
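Note: the input string changes shape with the checkpoint family: baseline models take the compact [SEP]-joined parameter string, prompt models a natural-language question. Matching the two strings suggests the baseline field order is alloy, laser power, scan speed, beam diameter, layer height, hatch spacing; that ordering is inferred from the prompt wording, not stated in the diff. A hypothetical formatter under that assumption:

    def baseline_text(alloy, power_w, speed_mm_s, beam_um, layer_um, hatch_um):
        # Inferred field order: 280.0 W -> power, 400.0 mm/s -> scan speed,
        # 100.0 -> beam diameter, 50.0 -> layer height, 100.0 -> hatch spacing.
        return "[SEP]".join([
            alloy,
            f"{power_w} W",
            f"{speed_mm_s} mm/s",
            f"{beam_um} microns",
            f"{layer_um} microns",
            f"{hatch_um} microns",
        ])

    baseline_text("Ti-6Al-4V", 280.0, 400.0, 100.0, 50.0, 100.0)
    # -> "Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns"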
@@ -99,7 +102,7 @@
     "# Tokenize inputs \n",
     "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
     "\n",
-    "# For
+    "# For SciBERT specific case. \n",
     "inputs_kwargs = {}\n",
     "for key, value in inputs.items():\n",
     "    if key not in [\"token_type_ids\"]:\n",
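Note: the clarified comment says the token_type_ids filter exists for the SciBERT case: its BERT-style tokenizer emits token_type_ids, which the other wrappers' forward() may not accept. A sketch of how the loop presumably completes and feeds the model, where the final call assumes the wrapper forwards keyword arguments:

    # Drop token_type_ids so checkpoints whose forward() lacks that
    # argument (e.g. DistilBERT) still accept the batch.
    inputs_kwargs = {
        key: value
        for key, value in inputs.items()
        if key not in ["token_type_ids"]
    }

    with torch.no_grad():                 # inference only
        outputs = model(**inputs_kwargs)  # assumed interface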