# LSM/submodules/lang_seg/modules/lsegmentation_module_zs.py
import types
import time
import random
import clip
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from argparse import ArgumentParser
import pytorch_lightning as pl
from encoding.models import get_segmentation_model
from encoding.nn import SegmentationLosses
from encoding.utils import batch_pix_accuracy, batch_intersection_union
# add mixed precision
import torch.cuda.amp as amp
import numpy as np
from encoding.utils.metrics import SegmentationMetric
# few-shot data loading and evaluation utilities
from fewshot_data.model.hsnet import HypercorrSqueezeNetwork
from fewshot_data.common.logger import Logger, AverageMeter
from fewshot_data.common.evaluation import Evaluator
from fewshot_data.common import utils
from fewshot_data.data.dataset import FSSDataset
class Fewshot_args:
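    """Default few-shot hyperparameters following the HSNet protocol (see fewshot_data)."""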
datapath = 'fewshot_data/Datasets_HSN'
benchmark = 'pascal'
logpath = ''
nworker = 8
bsz = 20
fold = 0
class LSegmentationModuleZS(pl.LightningModule):
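    """Lightning module that trains and evaluates zero-shot LSeg on few-shot segmentation episodes."""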
def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
super().__init__()
self.batch_size = batch_size
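        # scale the base learning rate linearly with batch size (reference batch size: 16)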
self.base_lr = base_lr / 16 * batch_size
self.lr = self.base_lr
self.epochs = max_epochs
self.other_kwargs = kwargs
        self.enabled = False  # if True, mixed precision can destabilize training and lead to NaN losses
        self.scaler = amp.GradScaler(enabled=self.enabled)  # no-op while AMP is disabled
        # whether to freeze the pretrained encoder
        self.fixed_encoder = kwargs["use_pretrained"] in ['clip_fixed']
# fewshot hyperparameters
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.args = self.get_fewshot_args()
if data_path:
self.args.datapath = data_path
self.args.logpath = self.other_kwargs["logpath"]
self.args.benchmark = dataset
self.args.bsz = self.batch_size
self.args.fold = self.other_kwargs["fold"]
self.args.nshot = self.other_kwargs["nshot"]
self.args.finetune_mode = self.other_kwargs["finetune_mode"]
Logger.initialize(self.args, training=True)
Evaluator.initialize()
if kwargs["backbone"] in ["clip_resnet101"]:
FSSDataset.initialize(img_size=480, datapath=self.args.datapath, use_original_imgsize=False, imagenet_norm=True)
else:
FSSDataset.initialize(img_size=480, datapath=self.args.datapath, use_original_imgsize=False)
self.best_val_miou = float('-inf')
self.num_classes = 2
        self.labels = ['others', '']  # two-class setup: background ('others') and the episode's target class
        self.fewshot_trn_loss = 100  # sentinels until the first training epoch completes
        self.fewshot_trn_miou = 0
        self.fewshot_trn_fb_iou = 0
def get_fewshot_args(self):
return Fewshot_args()
def forward(self, x, class_info):
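        # run the underlying segmentation network, conditioned on the episode class ids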
return self.net(x, class_info)
def training_step(self, batch, batch_nb):
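        # three cases: finetuning with 5-shot support sets, finetuning with single-shot
        # support sets, and episodic training on concatenated support + query pairs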
if self.args.finetune_mode:
if self.args.nshot == 5:
bshape = batch['support_imgs'].shape
img = batch['support_imgs'].view(-1, bshape[2], bshape[3], bshape[4])
target = batch['support_masks'].view(-1, bshape[3], bshape[4])
class_info = batch['class_id']
                # replicate class ids 5x so they align with the flattened 5-shot support images
                for _ in range(4):
                    class_info = torch.cat([class_info, batch['class_id']])
with amp.autocast(enabled=self.enabled):
out = self(img, class_info)
loss = self.criterion(out, target)
                    loss = self.scaler.scale(loss)  # identity while AMP is disabled
self.log("train_loss", loss)
                # evaluate prediction
if self.args.benchmark == 'pascal' and batch['support_ignore_idxs'] is not None:
query_ignore_idx = batch['support_ignore_idxs'].view(-1, bshape[3], bshape[4])
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target, query_ignore_idx)
else:
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target)
else:
img = batch['support_imgs'].squeeze(1)
target = batch['support_masks'].squeeze(1)
class_info = batch['class_id']
with amp.autocast(enabled=self.enabled):
out = self(img, class_info)
loss = self.criterion(out, target)
loss = self.scaler.scale(loss)
self.log("train_loss", loss)
                # evaluate prediction
if self.args.benchmark == 'pascal' and batch['support_ignore_idxs'] is not None:
query_ignore_idx = batch['support_ignore_idxs'].squeeze(1)
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target, query_ignore_idx)
else:
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target)
else:
img = torch.cat([batch['support_imgs'].squeeze(1), batch['query_img']], dim=0)
target = torch.cat([batch['support_masks'].squeeze(1), batch['query_mask']], dim=0)
            class_info = torch.cat([batch['class_id'], batch['class_id']], dim=0)
with amp.autocast(enabled=self.enabled):
out = self(img, class_info)
loss = self.criterion(out, target)
loss = self.scaler.scale(loss)
self.log("train_loss", loss)
            # evaluate prediction
if self.args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None:
query_ignore_idx = torch.cat([batch['support_ignore_idxs'].squeeze(1), batch['query_ignore_idx']], dim=0)
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target, query_ignore_idx)
else:
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target)
self.train_average_meter.update(area_inter, area_union, class_info, loss.detach().clone())
if self.global_rank == 0:
return_value = self.train_average_meter.write_process(batch_nb, self.len_train_dataloader, self.current_epoch, write_batch_idx=50)
if return_value is not None:
iou, fb_iou = return_value
self.log("fewshot_train_iou", iou)
                self.log("fewshot_train_fb_iou", fb_iou)
return loss
def training_epoch_end(self, outs):
if self.global_rank == 0:
self.train_average_meter.write_result('Training', self.current_epoch)
self.fewshot_trn_loss = utils.mean(self.train_average_meter.loss_buf)
self.fewshot_trn_miou, self.fewshot_trn_fb_iou = self.train_average_meter.compute_iou()
self.log("fewshot_trn_loss", self.fewshot_trn_loss)
self.log("fewshot_trn_miou", self.fewshot_trn_miou)
self.log("fewshot_trn_fb_iou", self.fewshot_trn_fb_iou)
def validation_step(self, batch, batch_nb):
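        # validate on query images; in 5-shot finetune mode the extra batch dimension is flattened first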
if self.args.finetune_mode and self.args.nshot == 5:
bshape = batch['query_img'].shape
img = batch['query_img'].view(-1, bshape[2], bshape[3], bshape[4])
target = batch['query_mask'].view(-1, bshape[3], bshape[4])
class_info = batch['class_id']
            # replicate class ids 5x so they align with the flattened 5-shot query images
            for _ in range(4):
                class_info = torch.cat([class_info, batch['class_id']])
out = self(img, class_info)
val_loss = self.criterion(out, target)
            # evaluate prediction
if self.args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None:
query_ignore_idx = batch['query_ignore_idx'].view(-1, bshape[3], bshape[4])
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target, query_ignore_idx)
else:
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target)
else:
img = batch['query_img'].squeeze(1)
target = batch['query_mask'].squeeze(1)
class_info = batch['class_id']
out = self(img, class_info)
val_loss = self.criterion(out, target)
            # evaluate prediction
if self.args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None:
query_ignore_idx = batch['query_ignore_idx'].squeeze(1)
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target, query_ignore_idx)
else:
area_inter, area_union = Evaluator.classify_prediction(out.argmax(dim=1), target)
self.val_average_meter.update(area_inter, area_union, class_info, val_loss.detach().clone())
if self.global_rank == 0:
return_value = self.val_average_meter.write_process(batch_nb, self.len_val_dataloader, self.current_epoch, write_batch_idx=50)
if return_value is not None:
iou, fb_iou = return_value
self.log("fewshot_val_iou", iou)
self.log("fewshot_val_fb_iou", fb_iou)
def validation_epoch_end(self, outs):
if self.global_rank == 0:
self.val_average_meter.write_result('Validation', self.current_epoch)
val_loss = utils.mean(self.val_average_meter.loss_buf)
val_miou, val_fb_iou = self.val_average_meter.compute_iou()
self.log("fewshot_val_loss", val_loss)
self.log("fewshot_val_miou", val_miou)
self.log("fewshot_val_fb_iou", val_fb_iou)
if self.global_rank == 0:
Logger.tbd_writer.add_scalars('fewshot_data/data/loss', {'trn_loss': self.fewshot_trn_loss, 'val_loss': val_loss}, self.current_epoch)
Logger.tbd_writer.add_scalars('fewshot_data/data/miou', {'trn_miou': self.fewshot_trn_miou, 'val_miou': val_miou}, self.current_epoch)
Logger.tbd_writer.add_scalars('fewshot_data/data/fb_iou', {'trn_fb_iou': self.fewshot_trn_fb_iou, 'val_fb_iou': val_fb_iou}, self.current_epoch)
Logger.tbd_writer.flush()
if self.current_epoch + 1 == self.epochs:
Logger.tbd_writer.close()
Logger.info('==================== Finished Training ====================')
        # stop finetuning early on pascal/coco after a few epochs
        threshold_epoch = 3
if self.args.benchmark in ['pascal', 'coco'] and self.current_epoch >= threshold_epoch:
print('End this loop!')
exit()
def configure_optimizers(self):
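        # decoder heads (scratch / auxlayer / scale convs) train at 10x the encoder learning rate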
# if we want to fix the encoder
if self.fixed_encoder:
params_list = [
{"params": self.net.pretrained.model.parameters(), "lr": 0},
]
params_list.append(
{"params": self.net.pretrained.act_postprocess1.parameters(), "lr": self.base_lr}
)
params_list.append(
{"params": self.net.pretrained.act_postprocess2.parameters(), "lr": self.base_lr}
)
params_list.append(
{"params": self.net.pretrained.act_postprocess3.parameters(), "lr": self.base_lr}
)
params_list.append(
{"params": self.net.pretrained.act_postprocess4.parameters(), "lr": self.base_lr}
)
else:
params_list = [
{"params": self.net.pretrained.parameters(), "lr": self.base_lr},
]
if hasattr(self.net, "scratch"):
print("Found output scratch")
params_list.append(
{"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
)
if hasattr(self.net, "auxlayer"):
print("Found auxlayer")
params_list.append(
{"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
)
if hasattr(self.net, "scale_inv_conv"):
print(self.net.scale_inv_conv)
print("Found scaleinv layers")
params_list.append(
{
"params": self.net.scale_inv_conv.parameters(),
"lr": self.base_lr * 10,
}
)
params_list.append(
{"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10}
)
params_list.append(
{"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10}
)
params_list.append(
{"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10}
)
if self.other_kwargs["midasproto"]:
print("Using midas optimization protocol")
opt = torch.optim.Adam(
params_list,
lr=self.base_lr,
betas=(0.9, 0.999),
weight_decay=self.other_kwargs["weight_decay"],
)
sch = torch.optim.lr_scheduler.LambdaLR(
opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
)
else:
opt = torch.optim.SGD(
params_list,
lr=self.base_lr,
momentum=0.9,
weight_decay=self.other_kwargs["weight_decay"],
)
sch = torch.optim.lr_scheduler.LambdaLR(
opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
)
return [opt], [sch]
def train_dataloader(self):
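        # in finetune mode, training episodes are drawn from the n-shot 'test' split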
if self.args.finetune_mode:
dataloader = FSSDataset.build_dataloader(
self.args.benchmark,
self.args.bsz,
self.args.nworker,
self.args.fold,
'test',
self.args.nshot)
else:
dataloader = FSSDataset.build_dataloader(
self.args.benchmark,
self.args.bsz,
self.args.nworker,
self.args.fold,
'trn')
        self.len_train_dataloader = len(dataloader) // torch.cuda.device_count()  # batches per GPU
self.train_average_meter = AverageMeter(dataloader.dataset)
return dataloader
def val_dataloader(self):
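        # reuse the n-shot 'test' split in finetune mode; otherwise validate on 'val'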
self.val_iou = SegmentationMetric(self.num_classes)
if self.args.finetune_mode:
dataloader = FSSDataset.build_dataloader(
self.args.benchmark,
self.args.bsz,
self.args.nworker,
self.args.fold,
'test',
self.args.nshot)
else:
dataloader = FSSDataset.build_dataloader(
self.args.benchmark,
self.args.bsz,
self.args.nworker,
self.args.fold,
'val')
        self.len_val_dataloader = len(dataloader) // torch.cuda.device_count()  # batches per GPU
self.val_average_meter = AverageMeter(dataloader.dataset)
return dataloader
def criterion(self, logit_mask, gt_mask):
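        # flatten spatial dimensions and apply per-pixel two-class cross-entropy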
bsz = logit_mask.size(0)
logit_mask = logit_mask.view(bsz, 2, -1)
gt_mask = gt_mask.view(bsz, -1).long()
return self.cross_entropy_loss(logit_mask, gt_mask)
@staticmethod
def add_model_specific_args(parent_parser):
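        # module-specific CLI flags, layered on top of the parent parser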
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument(
"--data_path",
type=str,
default='',
help="path where dataset is stored"
)
parser.add_argument(
"--dataset",
type=str,
default='pascal',
choices=['pascal', 'coco', 'fss'],
)
parser.add_argument(
"--batch_size", type=int, default=20, help="size of the batches"
)
parser.add_argument(
"--base_lr", type=float, default=0.004, help="learning rate"
)
parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
parser.add_argument(
"--weight_decay", type=float, default=1e-4, help="weight_decay"
)
        parser.add_argument(
            "--aux", action="store_true", default=False, help="auxiliary loss"
        )
parser.add_argument(
"--aux-weight",
type=float,
default=0.2,
            help="auxiliary loss weight (default: 0.2)",
)
parser.add_argument(
"--se-loss",
action="store_true",
default=False,
            help="use Semantic Encoding loss (SE-loss)",
)
parser.add_argument(
"--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
)
        parser.add_argument(
            "--midasproto", action="store_true", default=False, help="use the MiDaS optimization protocol (Adam)"
        )
parser.add_argument(
"--ignore_index",
type=int,
default=-1,
help="numeric value of ignore label in gt",
)
parser.add_argument(
"--augment",
action="store_true",
default=False,
help="Use extended augmentations",
)
parser.add_argument(
"--use_relabeled",
action="store_true",
default=False,
            help="use relabeled data",
)
parser.add_argument(
"--nworker",
type=int,
default=8
)
parser.add_argument(
"--fold",
type=int,
default=0,
choices=[0, 1, 2, 3]
)
parser.add_argument(
"--logpath",
type=str,
default=''
)
parser.add_argument(
"--nshot",
type=int,
            default=0  # 1
)
parser.add_argument(
"--finetune_mode",
action="store_true",
default=False,
            help="whether to run in finetune mode"
)
return parser