File size: 1,203 Bytes
8aed5f0
 
 
 
8fdfa17
8aed5f0
 
 
 
 
 
f783fb4
8aed5f0
8fdfa17
8aed5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from datetime import datetime
from pathlib import Path

from pet_seg_core.config import PetSegTrainConfig
from pet_seg_core.data import train_dataloader, val_dataloader
from pet_seg_core.model import UNet

def train():
    curr_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')
    results_folder = f"results/{curr_time}"
    Path(results_folder).mkdir(parents=True, exist_ok=True)
    with open(f"{results_folder}/description.txt", "w") as f:
        f.write(PetSegTrainConfig.DESCRIPTION_TEXT)

    logger = CSVLogger(save_dir="", name=results_folder, version="")
    checkpoint_callback = ModelCheckpoint(
        dirpath=results_folder,
        save_top_k=-1,
    )
    trainer = pl.Trainer(
        max_epochs=PetSegTrainConfig.EPOCHS, fast_dev_run=PetSegTrainConfig.FAST_DEV_RUN, logger=logger, callbacks=[checkpoint_callback], gradient_clip_val=1.0
    )
    model = UNet(3, 3, channels_list=PetSegTrainConfig.CHANNELS_LIST, depthwise_sep=PetSegTrainConfig.DEPTHWISE_SEP)

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)