from functools import partial

import lightning as pl
import torch
import torchvision.transforms.functional as TF
from torch import nn
from torchmetrics.classification import MulticlassConfusionMatrix
from torchmetrics.functional.segmentation import mean_iou

from pet_seg_core.config import PetSegTrainConfig


class DoubleConvOriginal(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes
    (``in_channels -> out_channels -> out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,  # BatchNorm directly after supplies the shift term
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv-BN-ReLU pair twice to ``x``."""
        return self.double_conv(x)


class DoubleConvDepthwiseSep(nn.Module):
    """Depthwise-separable variant of :class:`DoubleConvOriginal`.

    Each 3x3 convolution is replaced by a per-channel (``groups=channels``)
    3x3 convolution followed by a 1x1 convolution, then BatchNorm and ReLU.
    Spatial size is preserved (stride 1, padding 1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # Depthwise 3x3: one filter per input channel.
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=in_channels,
                bias=False,
            ),
            # Pointwise 1x1: mixes channels and sets the output width.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=out_channels,
                bias=False,
            ),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both depthwise-separable conv stages to ``x``."""
        return self.double_conv(x)


class UNet(pl.LightningModule):
    """U-Net segmentation model wrapped as a Lightning module.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of segmentation classes (channels of the logits).
        channels_list: Width of each encoder level, shallowest first. The
            decoder mirrors it and the bottleneck doubles the last entry.
            ``None`` selects ``[64, 128, 256, 512]``.
        depthwise_sep: Use :class:`DoubleConvDepthwiseSep` instead of
            :class:`DoubleConvOriginal` for every conv block.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels_list=None,
        depthwise_sep=False,
    ):
        super().__init__()

        # Resolve the default into a fresh list (no shared mutable default).
        # This is done before save_hyperparameters() so that the concrete
        # list, not None, is what ends up in the recorded hyperparameters.
        if channels_list is None:
            channels_list = [64, 128, 256, 512]
        else:
            channels_list = list(channels_list)

        self.save_hyperparameters()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        DoubleConv = DoubleConvDepthwiseSep if depthwise_sep else DoubleConvOriginal

        # Encoder: one conv block per level, widening the channels.
        for channels in channels_list:
            self.encoder.append(DoubleConv(in_channels, channels))
            in_channels = channels

        # Decoder: stored as alternating (upsample, conv block) pairs, deepest
        # level first. The conv block takes 2 * channels because the
        # upsampled tensor is concatenated with the matching skip connection.
        for channels in channels_list[::-1]:
            self.decoder.append(
                nn.ConvTranspose2d(channels * 2, channels, kernel_size=2, stride=2)
            )
            self.decoder.append(DoubleConv(channels * 2, channels))

        self.bottleneck = DoubleConv(channels_list[-1], channels_list[-1] * 2)

        # 1x1 conv mapping the shallowest feature map to per-class logits.
        self.out = nn.Conv2d(channels_list[0], out_channels, kernel_size=1)

        self.loss_fn = nn.CrossEntropyLoss()

        # Kept as a public attribute for existing callers; the training loop
        # itself derives IoU from the confusion matrices at epoch end.
        self.iou = partial(mean_iou, num_classes=out_channels)

        # One confusion matrix per stage. Validation runs inside a training
        # epoch, so a single shared matrix would mix train and val counts and
        # be reset before the train epoch-end hook reads it.
        # `conf_mat` keeps its original name and serves the training stage.
        self.conf_mat = MulticlassConfusionMatrix(num_classes=out_channels)
        self.val_conf_mat = MulticlassConfusionMatrix(num_classes=out_channels)
        self.test_conf_mat = MulticlassConfusionMatrix(num_classes=out_channels)

    def _conf_mat_for(self, prefix):
        """Return the confusion-matrix metric belonging to stage ``prefix``."""
        if prefix == "train":
            return self.conf_mat
        return getattr(self, f"{prefix}_conf_mat")

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return per-class logits."""
        skip_connections = []

        for enc_block in self.encoder:
            x = enc_block(x)
            skip_connections.append(x)  # saved before pooling
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder consumes the skips deepest-first.
        skip_connections = skip_connections[::-1]

        for i in range(0, len(self.decoder), 2):
            upsample = self.decoder[i]
            conv_block = self.decoder[i + 1]

            x = upsample(x)
            skip_connection = skip_connections[i // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; resize to match before concatenating.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=list(skip_connection.shape[2:]))

            # Concatenate along the channel dimension.
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = conv_block(concat_skip)

        return self.out(x)

    def _common_step(self, batch, batch_idx, prefix):
        """Shared train/val/test step.

        Computes the loss, logs it as ``{prefix}_loss`` and accumulates the
        stage's confusion matrix.

        Returns:
            Tuple ``(logits, loss)``.
        """
        x, y = batch

        # TODO: move this target decoding to the dataloader.
        # Maps a mask stored as value / 255 to zero-based class indices and
        # drops the channel dimension.
        y = (y * 255 - 1).long().squeeze(1)

        y_hat = self(x)

        loss = self.loss_fn(y_hat, y)
        # Log the detached tensor rather than loss.item() to avoid forcing a
        # host synchronisation on every step.
        self.log(f"{prefix}_loss", loss.detach(), prog_bar=True)

        y_hat_argmax = torch.argmax(y_hat, dim=1)
        self._conf_mat_for(prefix).update(y_hat_argmax, y)

        return y_hat, loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; returns the loss to optimise."""
        y_hat, loss = self._common_step(batch, batch_idx, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs loss and updates val metrics."""
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Lightning test hook; logs loss and updates test metrics."""
        self._common_step(batch, batch_idx, "test")

    def _common_on_epoch_end(self, prefix):
        """Log the stage's confusion matrix and per-class IoU, then reset it.

        Confusion-matrix entry ``[i, j]`` is logged as
        ``{prefix}_confmat_true={i}_pred={j}`` and the IoU of class ``i`` as
        ``{prefix}_iou_class={i}``. A class that appears in neither targets
        nor predictions gets an IoU of NaN.
        """
        metric = self._conf_mat_for(prefix)
        # Single transfer to host; every value below is read out as a scalar.
        confmat = metric.compute().cpu()

        for i in range(self.out_channels):
            for j in range(self.out_channels):
                self.log(
                    f'{prefix}_confmat_true={i}_pred={j}',
                    confmat[i][j].item(),
                    prog_bar=True,
                )

        true_positive = confmat.diag()
        predicted_total = confmat.sum(dim=0)  # column sums: TP + FP
        actual_total = confmat.sum(dim=1)  # row sums: TP + FN
        union = predicted_total + actual_total - true_positive

        # clamp avoids a 0/0 division; those entries are replaced by NaN.
        ratio = true_positive.float() / union.clamp(min=1).float()
        iou = torch.where(union > 0, ratio, torch.full_like(ratio, float('nan')))

        for i in range(self.out_channels):
            self.log(f'{prefix}_iou_class={i}', iou[i].item(), prog_bar=True)

        metric.reset()

    def on_train_epoch_end(self):
        """Log and reset training-stage metrics."""
        self._common_on_epoch_end("train")

    def on_validation_epoch_end(self):
        """Log and reset validation-stage metrics."""
        self._common_on_epoch_end("val")

    def on_test_epoch_end(self):
        """Log and reset test-stage metrics."""
        self._common_on_epoch_end("test")

    def configure_optimizers(self):
        """Adam over all parameters at ``PetSegTrainConfig.LEARNING_RATE``."""
        return torch.optim.Adam(self.parameters(), lr=PetSegTrainConfig.LEARNING_RATE)