# LSM/submodules/lang_seg/modules/lsegmentation_module.py
# (Hugging Face page residue preserved: kairunwen — "Update Code", commit 57746f1)
import types
import time
import random
import clip
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from argparse import ArgumentParser
import pytorch_lightning as pl
from data import get_dataset, get_available_datasets
from encoding.models import get_segmentation_model
from encoding.nn import SegmentationLosses
from encoding.utils import batch_pix_accuracy, batch_intersection_union
# add mixed precision
import torch.cuda.amp as amp
import numpy as np
from encoding.utils import SegmentationMetric
class LSegmentationModule(pl.LightningModule):
    """Lightning module wrapping a language-driven segmentation network.

    Handles the training/validation loops, metric bookkeeping, optimizer
    construction, and dataset/dataloader plumbing; the actual network is
    expected to be attached as ``self.net`` by a subclass.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        super().__init__()
        # Dataset location and loader sizing.
        self.data_path = data_path
        self.batch_size = batch_size
        self.epochs = max_epochs
        # Linear LR scaling rule: ``base_lr`` is specified for a batch of 16.
        self.base_lr = base_lr / 16 * batch_size
        self.lr = self.base_lr
        # Remaining hyper-parameters (ignore_index, weight_decay, ...).
        self.other_kwargs = kwargs
        # NOTE: mixed precision is deliberately disabled — enabling it has
        # produced NaN losses with this model.
        self.enabled = False
        self.scaler = amp.GradScaler(enabled=self.enabled)
def forward(self, x):
return self.net(x)
def evaluate(self, x, target=None):
pred = self.net.forward(x)
if isinstance(pred, (tuple, list)):
pred = pred[0]
if target is None:
return pred
correct, labeled = batch_pix_accuracy(pred.data, target.data)
inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
return correct, labeled, inter, union
def evaluate_random(self, x, labelset, target=None):
pred = self.net.forward(x, labelset)
if isinstance(pred, (tuple, list)):
pred = pred[0]
if target is None:
return pred
correct, labeled = batch_pix_accuracy(pred.data, target.data)
inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
return correct, labeled, inter, union
def training_step(self, batch, batch_nb):
img, target = batch
with amp.autocast(enabled=self.enabled):
out = self(img)
multi_loss = isinstance(out, tuple)
if multi_loss:
loss = self.criterion(*out, target)
else:
loss = self.criterion(out, target)
loss = self.scaler.scale(loss)
final_output = out[0] if multi_loss else out
train_pred, train_gt = self._filter_invalid(final_output, target)
if train_gt.nelement() != 0:
self.train_accuracy(train_pred, train_gt)
self.log("train_loss", loss)
return loss
def training_epoch_end(self, outs):
self.log("train_acc_epoch", self.train_accuracy.compute())
def validation_step(self, batch, batch_nb):
img, target = batch
out = self(img)
multi_loss = isinstance(out, tuple)
if multi_loss:
val_loss = self.criterion(*out, target)
else:
val_loss = self.criterion(out, target)
final_output = out[0] if multi_loss else out
valid_pred, valid_gt = self._filter_invalid(final_output, target)
self.val_iou.update(target, final_output)
pixAcc, iou = self.val_iou.get()
self.log("val_loss_step", val_loss)
self.log("pix_acc_step", pixAcc)
self.log(
"val_acc_step",
self.val_accuracy(valid_pred, valid_gt),
)
self.log("val_iou", iou)
def validation_epoch_end(self, outs):
pixAcc, iou = self.val_iou.get()
self.log("val_acc_epoch", self.val_accuracy.compute())
self.log("val_iou_epoch", iou)
self.log("pix_acc_epoch", pixAcc)
self.val_iou.reset()
def _filter_invalid(self, pred, target):
valid = target != self.other_kwargs["ignore_index"]
_, mx = torch.max(pred, dim=1)
return mx[valid], target[valid]
def configure_optimizers(self):
params_list = [
{"params": self.net.pretrained.parameters(), "lr": self.base_lr},
]
if hasattr(self.net, "scratch"):
print("Found output scratch")
params_list.append(
{"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
)
if hasattr(self.net, "auxlayer"):
print("Found auxlayer")
params_list.append(
{"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
)
if hasattr(self.net, "scale_inv_conv"):
print(self.net.scale_inv_conv)
print("Found scaleinv layers")
params_list.append(
{
"params": self.net.scale_inv_conv.parameters(),
"lr": self.base_lr * 10,
}
)
params_list.append(
{"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10}
)
params_list.append(
{"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10}
)
params_list.append(
{"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10}
)
if self.other_kwargs["midasproto"]:
print("Using midas optimization protocol")
opt = torch.optim.Adam(
params_list,
lr=self.base_lr,
betas=(0.9, 0.999),
weight_decay=self.other_kwargs["weight_decay"],
)
sch = torch.optim.lr_scheduler.LambdaLR(
opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
)
else:
opt = torch.optim.SGD(
params_list,
lr=self.base_lr,
momentum=0.9,
weight_decay=self.other_kwargs["weight_decay"],
)
sch = torch.optim.lr_scheduler.LambdaLR(
opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
)
return [opt], [sch]
def train_dataloader(self):
return torch.utils.data.DataLoader(
self.trainset,
batch_size=self.batch_size,
shuffle=True,
num_workers=16,
worker_init_fn=lambda x: random.seed(time.time() + x),
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self.valset,
batch_size=self.batch_size,
shuffle=False,
num_workers=16,
)
def get_trainset(self, dset, augment=False, **kwargs):
print(kwargs)
if augment == True:
mode = "train_x"
else:
mode = "train"
print(mode)
dset = get_dataset(
dset,
root=self.data_path,
split="train",
mode=mode,
transform=self.train_transform,
**kwargs
)
self.num_classes = dset.num_class
self.train_accuracy = pl.metrics.Accuracy()
return dset
def get_valset(self, dset, augment=False, **kwargs):
self.val_accuracy = pl.metrics.Accuracy()
self.val_iou = SegmentationMetric(self.num_classes)
if augment == True:
mode = "val_x"
else:
mode = "val"
print(mode)
return get_dataset(
dset,
root=self.data_path,
split="val",
mode=mode,
transform=self.val_transform,
**kwargs
)
def get_criterion(self, **kwargs):
return SegmentationLosses(
se_loss=kwargs["se_loss"],
aux=kwargs["aux"],
nclass=self.num_classes,
se_weight=kwargs["se_weight"],
aux_weight=kwargs["aux_weight"],
ignore_index=kwargs["ignore_index"],
)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument(
"--data_path", type=str, help="path where dataset is stored"
)
parser.add_argument(
"--dataset",
choices=get_available_datasets(),
default="ade20k",
help="dataset to train on",
)
parser.add_argument(
"--batch_size", type=int, default=16, help="size of the batches"
)
parser.add_argument(
"--base_lr", type=float, default=0.004, help="learning rate"
)
parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
parser.add_argument(
"--weight_decay", type=float, default=1e-4, help="weight_decay"
)
parser.add_argument(
"--aux", action="store_true", default=False, help="Auxilary Loss"
)
parser.add_argument(
"--aux-weight",
type=float,
default=0.2,
help="Auxilary loss weight (default: 0.2)",
)
parser.add_argument(
"--se-loss",
action="store_true",
default=False,
help="Semantic Encoding Loss SE-loss",
)
parser.add_argument(
"--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
)
parser.add_argument(
"--midasproto", action="store_true", default=False, help="midasprotocol"
)
parser.add_argument(
"--ignore_index",
type=int,
default=-1,
help="numeric value of ignore label in gt",
)
parser.add_argument(
"--augment",
action="store_true",
default=False,
help="Use extended augmentations",
)
return parser