|
import lightning as pl |
|
from lightning.pytorch.callbacks import ModelCheckpoint |
|
from lightning.pytorch.loggers import CSVLogger |
|
from datetime import datetime |
|
from pathlib import Path |
|
|
|
from pet_seg_core.config import PetSegTrainConfig |
|
from pet_seg_core.data import train_dataloader, val_dataloader |
|
from pet_seg_core.model import UNet |
|
|
|
def train():
    """Train a UNet and store logs, checkpoints and a run description on disk.

    A fresh folder ``results/<timestamp>`` is created for the run (timestamp
    format ``%Y-%m-%d_%H-%M-%S-%f``). ``PetSegTrainConfig.DESCRIPTION_TEXT`` is
    written to ``description.txt`` inside it, and the same folder is handed to
    both the CSV logger and the checkpoint callback. The model is then fitted
    with the imported ``train_dataloader`` / ``val_dataloader``.

    Returns:
        None.
    """
    curr_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')
    results_folder = f"results/{curr_time}"
    results_dir = Path(results_folder)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Write with an explicit encoding: the platform default may not be UTF-8,
    # which can fail or garble the file when the description has non-ASCII text.
    (results_dir / "description.txt").write_text(
        PetSegTrainConfig.DESCRIPTION_TEXT, encoding="utf-8"
    )

    # NOTE(review): empty save_dir/version with name=results_folder presumably
    # makes the logger write directly into the run folder instead of a
    # "version_N" subfolder -- confirm against the CSVLogger documentation.
    logger = CSVLogger(save_dir="", name=results_folder, version="")

    # save_top_k=-1: per Lightning's ModelCheckpoint docs, keep every
    # checkpoint rather than only the best k.
    checkpoint_callback = ModelCheckpoint(
        dirpath=results_folder,
        save_top_k=-1,
    )

    trainer = pl.Trainer(
        max_epochs=PetSegTrainConfig.EPOCHS,
        fast_dev_run=PetSegTrainConfig.FAST_DEV_RUN,
        logger=logger,
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
    )

    # NOTE(review): the two positional 3s are presumably input/output channel
    # counts -- verify against the UNet constructor signature.
    model = UNet(
        3,
        3,
        channels_list=PetSegTrainConfig.CHANNELS_LIST,
        depthwise_sep=PetSegTrainConfig.DEPTHWISE_SEP,
    )

    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )
|
|