ppak10 commited on
Commit
58558c8
·
1 Parent(s): 6a286a9

Updates inference notebook.

Browse files
Files changed (1) hide show
  1. model/notebooks/inference.ipynb +11 -8
model/notebooks/inference.ipynb CHANGED
@@ -16,7 +16,7 @@
16
  "import torch\n",
17
  "\n",
18
  "from huggingface_hub import hf_hub_download\n",
19
- "from transformers import AutoTokenizer\n",
20
  "\n",
21
  "from model.distilbert import DistilBertClassificationModel\n",
22
  "from model.scibert import SciBertClassificationModel\n",
@@ -39,7 +39,7 @@
39
  "outputs": [],
40
  "source": [
41
  "# Baseline\n",
42
- "repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
43
  "# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
44
  "# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
45
  "# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
@@ -48,7 +48,7 @@
48
  "# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
49
  "# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
50
  "# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
51
- "# repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
52
  "\n",
53
  "# Initialize the model\n",
54
  "model = DistilBertClassificationModel(repo_id)\n",
@@ -63,8 +63,11 @@
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
66
- "# Load the tokenizer\n",
67
- "tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
 
 
 
68
  "\n",
69
  "# Loads classification head weights\n",
70
  "classification_head_path = hf_hub_download(\n",
@@ -84,10 +87,10 @@
84
  "outputs": [],
85
  "source": [
86
  "# Baseline\n",
87
- "text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
88
  "\n",
89
  "# Prompt\n",
90
- "# text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
91
  ]
92
  },
93
  {
@@ -99,7 +102,7 @@
99
  "# Tokenize inputs \n",
100
  "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
101
  "\n",
102
- "# For scibert\n",
103
  "inputs_kwargs = {}\n",
104
  "for key, value in inputs.items():\n",
105
  " if key not in [\"token_type_ids\"]:\n",
 
16
  "import torch\n",
17
  "\n",
18
  "from huggingface_hub import hf_hub_download\n",
19
+ "from transformers import AutoTokenizer, T5Tokenizer\n",
20
  "\n",
21
  "from model.distilbert import DistilBertClassificationModel\n",
22
  "from model.scibert import SciBertClassificationModel\n",
 
39
  "outputs": [],
40
  "source": [
41
  "# Baseline\n",
42
+ "# repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
43
  "# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
44
  "# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
45
  "# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
 
48
  "# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
49
  "# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
50
  "# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
51
+ "repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
52
  "\n",
53
  "# Initialize the model\n",
54
  "model = DistilBertClassificationModel(repo_id)\n",
 
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
66
+ "# Uncomment for DistilBERT, SciBERT, and Llama\n",
67
+ "# tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
68
+ "\n",
69
+ "# Uncomment for T5\n",
70
+ "tokenizer = T5Tokenizer.from_pretrained(repo_id)\n",
71
  "\n",
72
  "# Loads classification head weights\n",
73
  "classification_head_path = hf_hub_download(\n",
 
87
  "outputs": [],
88
  "source": [
89
  "# Baseline\n",
90
+ "# text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
91
  "\n",
92
  "# Prompt\n",
93
+ "text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
94
  ]
95
  },
96
  {
 
102
  "# Tokenize inputs \n",
103
  "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
104
  "\n",
105
+ "# For SciBERT specific case. \n",
106
  "inputs_kwargs = {}\n",
107
  "for key, value in inputs.items():\n",
108
  " if key not in [\"token_type_ids\"]:\n",