File size: 3,906 Bytes
e8f2571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
from argparse import ArgumentParser

import mmcv
import requests
import torch
from mmengine.structures import InstanceData

from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample


def parse_args():
    """Build and parse the command-line arguments for this comparison tool.

    Returns:
        argparse.Namespace: Parsed arguments with positional fields
        (``img``, ``config``, ``checkpoint``, ``model_name``) and optional
        fields (``inference_addr``, ``device``, ``score_thr``, ``work_dir``).
    """
    parser = ArgumentParser()
    # Required positional arguments.
    for name, desc in (
        ('img', 'Image file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
        ('model_name', 'The model name in the server'),
    ):
        parser.add_argument(name, help=desc)
    # Optional flags with defaults.
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='bbox score threshold')
    parser.add_argument(
        '--work-dir',
        type=str,
        default=None,
        help='output directory to save drawn results.')
    return parser.parse_args()


def align_ts_output(inputs, metainfo, device):
    """Convert raw TorchServe predictions into a ``DetDataSample``.

    Args:
        inputs (list[dict]): Per-detection dicts returned by the TorchServe
            endpoint; each is assumed to carry ``'bbox'``, ``'class_label'``
            and ``'score'`` keys (TODO confirm against the server handler).
        metainfo (dict): Image meta information copied from the PyTorch
            result so both samples share the same context.
        device (str): Device string (e.g. ``'cuda:0'``) the tensors should
            live on, matching the local inference device.

    Returns:
        DetDataSample: Sample whose ``pred_instances`` holds the bboxes,
        labels and scores as tensors on ``device``.
    """
    # Unzip the list of prediction dicts into parallel field lists.
    # (Original used enumerate with an unused index.)
    bboxes = [pred['bbox'] for pred in inputs]
    labels = [pred['class_label'] for pred in inputs]
    scores = [pred['score'] for pred in inputs]
    pred_instances = InstanceData(metainfo=metainfo)
    pred_instances.bboxes = torch.tensor(
        bboxes, dtype=torch.float32, device=device)
    pred_instances.labels = torch.tensor(
        labels, dtype=torch.int64, device=device)
    pred_instances.scores = torch.tensor(
        scores, dtype=torch.float32, device=device)
    ts_data_sample = DetDataSample(pred_instances=pred_instances)
    return ts_data_sample


def main(args):
    """Run local PyTorch inference and a TorchServe request on one image and
    assert that the two sets of detections agree.

    Args:
        args (argparse.Namespace): Parsed CLI arguments from ``parse_args``.

    Raises:
        requests.HTTPError: If the inference server returns an error status.
        AssertionError: If the PyTorch and TorchServe results diverge.
    """
    # Build the model from a config file and a checkpoint file.
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Test a single image locally and drop low-confidence detections.
    pytorch_results = inference_detector(model, args.img)
    keep = pytorch_results.pred_instances.scores >= args.score_thr
    pytorch_results.pred_instances = pytorch_results.pred_instances[keep]

    # Init visualizer; the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector.
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # Show (and optionally save) the local result.
    img = mmcv.imread(args.img)
    img = mmcv.imconvert(img, 'bgr', 'rgb')
    pt_out_file = None
    ts_out_file = None
    if args.work_dir is not None:
        os.makedirs(args.work_dir, exist_ok=True)
        pt_out_file = os.path.join(args.work_dir, 'pytorch_result.png')
        ts_out_file = os.path.join(args.work_dir, 'torchserve_result.png')
    visualizer.add_datasample(
        'pytorch_result',
        img.copy(),
        data_sample=pytorch_results,
        draw_gt=False,
        out_file=pt_out_file,
        show=True,
        wait_time=0)

    # Query the TorchServe endpoint with the raw image bytes.
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        response = requests.post(url, image)
    # Fail with a clear HTTP error instead of a confusing JSON decode
    # error when the server responds with a 4xx/5xx status.
    response.raise_for_status()
    metainfo = pytorch_results.pred_instances.metainfo
    ts_results = align_ts_output(response.json(), metainfo, args.device)

    visualizer.add_datasample(
        'torchserve_result',
        img,
        data_sample=ts_results,
        draw_gt=False,
        out_file=ts_out_file,
        show=True,
        wait_time=0)

    # The deployed model should reproduce the local predictions.
    assert torch.allclose(pytorch_results.pred_instances.bboxes,
                          ts_results.pred_instances.bboxes)
    assert torch.allclose(pytorch_results.pred_instances.labels,
                          ts_results.pred_instances.labels)
    assert torch.allclose(pytorch_results.pred_instances.scores,
                          ts_results.pred_instances.scores)


if __name__ == '__main__':
    # Parse CLI arguments and run the comparison end to end.
    main(parse_args())