import torch
import torchvision
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision import transforms as T

from pet_seg_core.config import PetSegTrainConfig

# Shared transform for images and masks. NEAREST interpolation keeps mask
# pixels on valid class labels instead of blending them into in-between values.
# Note: T.ToTensor() also rescales the uint8 mask to [0, 1], so downstream code
# must map mask values back to integer class indices.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
    ]
)

print(f"Downloading data")

# Download the dataset
train_val_ds = torchvision.datasets.OxfordIIITPet(
    root=PetSegTrainConfig.TRAIN_VAL_TEST_DATA_PATH,
    split="trainval",
    target_types="segmentation",
    transform=transform,
    target_transform=transform,
    download=True,
)

print(f"Downloaded data")

# Optionally restrict the dataset to a random subset of TOTAL_SAMPLES samples
if PetSegTrainConfig.TOTAL_SAMPLES > 0:
    train_val_ds = torch.utils.data.Subset(
        train_val_ds, torch.randperm(len(train_val_ds))[:PetSegTrainConfig.TOTAL_SAMPLES]
    )

# Split 80/20 into train and a holdout set, then split the holdout 50/50 into
# test and val, giving an 80/10/10 train/val/test split overall.
train_ds, val_ds = torch.utils.data.random_split(
    train_val_ds,
    [int(0.8 * len(train_val_ds)), len(train_val_ds) - int(0.8 * len(train_val_ds))],
)

test_ds, val_ds = torch.utils.data.random_split(
    val_ds,
    [int(0.5 * len(val_ds)), len(val_ds) - int(0.5 * len(val_ds))],
)

train_dataloader = DataLoader(
    train_ds,  # The training samples.
    sampler=RandomSampler(train_ds),  # Select batches randomly
    batch_size=PetSegTrainConfig.BATCH_SIZE,  # Trains with this batch size.
    num_workers=3,
    persistent_workers=True,
)

# For validation the order doesn't matter, so we'll just read them sequentially.
val_dataloader = DataLoader(
    val_ds,  # The validation samples.
    sampler=SequentialSampler(val_ds),  # Pull out batches sequentially.
    batch_size=PetSegTrainConfig.BATCH_SIZE,  # Evaluate with this batch size.
    num_workers=3,
    persistent_workers=True,
)

# The order doesn't matter for testing either, so read sequentially as well.
test_dataloader = DataLoader(
    test_ds,  # The test samples.
    sampler=SequentialSampler(test_ds),  # Pull out batches sequentially.
    batch_size=PetSegTrainConfig.BATCH_SIZE,  # Evaluate with this batch size.
    num_workers=3,
    persistent_workers=True,
)
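
# Quick sanity check, as a minimal sketch: pull one batch and confirm that the
# image and mask shapes line up. Assumes only the loaders defined above; the
# __main__ guard also keeps the worker processes from re-running this block.
if __name__ == "__main__":
    images, masks = next(iter(train_dataloader))
    print(f"Image batch shape: {tuple(images.shape)}")  # (BATCH_SIZE, 3, 128, 128)
    print(f"Mask batch shape:  {tuple(masks.shape)}")   # (BATCH_SIZE, 1, 128, 128)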