|
import os |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
from modules.lseg_module_zs import LSegModuleZS |
|
from additional_utils.models import LSeg_MultiEvalModule |
|
from fewshot_data.common.logger import Logger, AverageMeter |
|
from fewshot_data.common.vis import Visualizer |
|
from fewshot_data.common.evaluation import Evaluator |
|
from fewshot_data.common import utils |
|
from fewshot_data.data.dataset import FSSDataset |
|
|
|
|
|
class Options:
    """Command-line options for the few-shot / zero-shot segmentation test script.

    The constructor only builds the ``argparse`` parser and stores it on
    ``self.parser``; nothing is parsed until :meth:`parse` is called.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description="PyTorch Segmentation")

        # Model and dataset selection.
        parser.add_argument(
            "--model", type=str, default="encnet", help="model name (default: encnet)"
        )
        parser.add_argument(
            "--backbone",
            type=str,
            default="resnet50",
            help="backbone name (default: resnet50)",
        )
        parser.add_argument(
            "--dataset",
            type=str,
            default="ade20k",
            help="dataset name (default: ade20k)",
        )
        parser.add_argument(
            "--workers", type=int, default=16, metavar="N", help="dataloader threads"
        )
        parser.add_argument(
            "--base-size", type=int, default=520, help="base image size"
        )
        parser.add_argument(
            "--crop-size", type=int, default=480, help="crop image size"
        )
        parser.add_argument(
            "--train-split",
            type=str,
            default="train",
            help="dataset train split (default: train)",
        )

        # Loss-related switches.
        parser.add_argument(
            "--aux", action="store_true", default=False, help="Auxiliary Loss"
        )
        parser.add_argument(
            "--se-loss",
            action="store_true",
            default=False,
            help="Semantic Encoding Loss SE-loss",
        )
        parser.add_argument(
            "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=16,
            metavar="N",
            help="input batch size for training (default: 16)",
        )
        parser.add_argument(
            "--test-batch-size",
            type=int,
            default=16,
            metavar="N",
            help="input batch size for testing (default: 16)",
        )

        # Device and reproducibility.
        parser.add_argument(
            "--no-cuda",
            action="store_true",
            default=False,
            help="disables CUDA training",
        )
        parser.add_argument(
            "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
        )

        # Checkpoint to load.
        parser.add_argument(
            "--weights", type=str, default=None, help="checkpoint to test"
        )

        # Evaluation switches.
        parser.add_argument(
            "--eval", action="store_true", default=False, help="evaluating mIoU"
        )
        parser.add_argument(
            "--acc-bn",
            action="store_true",
            default=False,
            help="Re-accumulate BN statistics",
        )
        parser.add_argument(
            "--test-val",
            action="store_true",
            default=False,
            help="generate masks on val set",
        )
        parser.add_argument(
            "--no-val",
            action="store_true",
            default=False,
            help="skip validation during training",
        )

        parser.add_argument(
            "--module",
            default='',
            help="select model definition",
        )

        # Architecture switches. The "--no-*" flags store False into a
        # positively named destination that defaults to True.
        parser.add_argument(
            "--no-scaleinv",
            dest="scale_inv",
            default=True,
            action="store_false",
            help="turn off scaleinv layers",
        )
        parser.add_argument(
            "--widehead", default=False, action="store_true", help="wider output head"
        )
        parser.add_argument(
            "--widehead_hr",
            default=False,
            action="store_true",
            help="wider output head",
        )
        parser.add_argument(
            "--ignore_index",
            type=int,
            default=-1,
            help="numeric value of ignore label in gt",
        )
        parser.add_argument(
            "--jobname",
            type=str,
            default="default",
            help="job name (default: default)",
        )
        parser.add_argument(
            "--no-strict",
            dest="strict",
            default=True,
            action="store_false",
            help="no-strict copy the model",
        )
        # NOTE: this option is a *string* ("True" by default), not a boolean
        # flag; it is kept that way so existing command lines keep working.
        parser.add_argument(
            "--use_pretrained",
            type=str,
            default="True",
            help="whether use the default model to initialize the model",
        )
        parser.add_argument(
            "--arch_option",
            type=int,
            default=0,
            help="which kind of architecture to be used",
        )

        # Few-shot benchmark options.
        parser.add_argument(
            '--nshot',
            type=int,
            default=1
        )
        parser.add_argument(
            '--fold',
            type=int,
            default=0,
            choices=[0, 1, 2, 3]
        )
        parser.add_argument(
            '--nworker',
            type=int,
            default=0
        )
        parser.add_argument(
            '--bsz',
            type=int,
            default=1
        )
        parser.add_argument(
            '--benchmark',
            type=str,
            default='pascal',
            choices=['pascal', 'coco', 'fss', 'c2p']
        )
        parser.add_argument(
            '--datapath',
            type=str,
            default='fewshot_data/Datasets_HSN'
        )

        parser.add_argument(
            "--activation",
            choices=['relu', 'lrelu', 'tanh'],
            default="relu",
            help="use which activation to activate the block",
        )

        self.parser = parser

    def parse(self, argv=None):
        """Parse the command line and return the resulting namespace.

        Args:
            argv: Optional list of argument strings. When ``None`` (the
                default), ``argparse`` reads ``sys.argv[1:]``, which is the
                behaviour callers had before this parameter existed.

        Returns:
            The ``argparse.Namespace`` with one extra attribute, ``cuda``,
            which is True only when ``--no-cuda`` was not given and
            ``torch.cuda.is_available()`` reports a usable device.
        """
        args = self.parser.parse_args(argv)
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        print(args)
        return args
|
|
|
|
|
def test(args):
    """Evaluate an ``LSegModuleZS`` checkpoint on a few-shot segmentation test split.

    The function loads the checkpoint at ``args.weights``, builds the ``'test'``
    dataloader through ``FSSDataset``, runs the network on every query image and
    accumulates intersection/union areas in an ``AverageMeter``. The final mIoU
    and FB-IoU are logged through ``Logger`` and appended to
    ``logs/fewshot/log_fewshot-test_nshot<nshot>_<dataset>.txt``.

    Args:
        args: Parsed options as produced by ``Options.parse``. The attributes
            read are ``weights``, ``datapath``, ``dataset``, ``backbone``,
            ``aux``, ``ignore_index``, ``scale_inv``, ``widehead``,
            ``widehead_hr``, ``arch_option``, ``use_pretrained``, ``strict``,
            ``fold``, ``nshot``, ``activation``, ``bsz`` and ``nworker``.

    Side effects:
        * ``args.benchmark`` is overwritten with ``args.dataset``, so whatever
          was passed via ``--benchmark`` is ignored.
        * The model and every batch are moved to CUDA unconditionally, so a
          CUDA device is required regardless of ``--no-cuda``.
    """
    module = LSegModuleZS.load_from_checkpoint(
        checkpoint_path=args.weights,
        data_path=args.datapath,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        # Was misspelled "map_locatin", which meant the CPU mapping request
        # never reached the checkpoint loader.
        map_location="cpu",
        arch_option=args.arch_option,
        use_pretrained=args.use_pretrained,
        strict=args.strict,
        logpath='fewshot/logpath_4T/',
        fold=args.fold,
        block_depth=0,
        nshot=args.nshot,
        finetune_mode=False,
        activation=args.activation,
    )

    Evaluator.initialize()
    if args.backbone in ["clip_resnet101"]:
        FSSDataset.initialize(img_size=480, datapath=args.datapath, use_original_imgsize=False, imagenet_norm=True)
    else:
        FSSDataset.initialize(img_size=480, datapath=args.datapath, use_original_imgsize=False)

    # The dataset name doubles as the benchmark name from here on.
    args.benchmark = args.dataset
    dataloader = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    model = module.net.eval().cuda()

    print(model)

    # Create the log directory up front so a missing directory fails before the
    # evaluation runs instead of after it.
    log_path = "logs/fewshot/log_fewshot-test_nshot{}_{}.txt".format(args.nshot, args.dataset)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    utils.fix_randseed(0)
    average_meter = AverageMeter(dataloader.dataset)

    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            batch = utils.to_cuda(batch)
            image = batch['query_img']
            target = batch['query_mask']
            class_info = batch['class_id']

            pred_mask = model(image, class_info)

            if args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None:
                query_ignore_idx = batch['query_ignore_idx']
                area_inter, area_union = Evaluator.classify_prediction(pred_mask.argmax(dim=1), target, query_ignore_idx)
            else:
                area_inter, area_union = Evaluator.classify_prediction(pred_mask.argmax(dim=1), target)

            average_meter.update(area_inter, area_union, class_info, loss=None)
            average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)

    average_meter.write_result('Test', 0)
    test_miou, test_fb_iou = average_meter.compute_iou()

    summary = 'Fold %d, %d-shot ==> mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, args.nshot, test_miou.item(), test_fb_iou.item())
    Logger.info(summary)
    Logger.info('==================== Finished Testing ====================')

    # Append so results of successive runs accumulate in one file; the context
    # manager closes the handle even if a write fails.
    with open(log_path, "a") as f:
        f.write('{}\n'.format(args.weights))
        f.write(summary + '\n')
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the command line, seed torch's RNG from
    # --seed, then run the evaluation.
    options = Options()
    args = options.parse()
    torch.manual_seed(args.seed)
    test(args)
|
|