import torch
import warnings
import torchvision
from .stylematte import StyleMatte

class StyleMatteEngine(torch.nn.Module):
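    """Thin wrapper around StyleMatte for human matting.

    Loads a pretrained checkpoint and exposes a `forward` pass that predicts an
    alpha matte and, optionally, images composited against a solid background.
    """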
    def __init__(self, device='cpu', human_matting_path='./pretrain_model/matting/stylematte_synth.pt'):
        super().__init__()
        self._device = device
        self._ckpt_path = human_matting_path
        # ImageNet normalization statistics
        self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self._init_models(human_matting_path)

    def _init_models(self, _ckpt_path):
        # load checkpoint weights
        state_dict = torch.load(_ckpt_path, map_location='cpu')
        # build the model and move it to the target device in eval mode
        model = StyleMatte()
        model.load_state_dict(state_dict)
        self.model = model.to(self._device).eval()
    
    @torch.no_grad()
    def forward(self, input_image, return_type='matting', background_rgb=1.0):
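        """Predict an alpha matte for `input_image`, a (3, H, W) tensor scaled to [0, 1].

        return_type:
            'alpha'   -> alpha matte of shape (H, W)
            'matting' -> (image composited over a solid `background_rgb`, alpha matte)
            'all'     -> (foreground composited over `background_rgb`,
                          background composited over `background_rgb`)
        """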
        if not hasattr(self, 'model'):
            self._init_models(self._ckpt_path)
        if input_image.max() > 2.0:
            warnings.warn('Image should be normalized to [0, 1].')
        _, ori_h, ori_w = input_image.shape
        input_image = input_image.to(self._device).float()
        image = input_image.clone()
        # resize
        if max(ori_h, ori_w) > 1024:
            scale = 1024.0 / max(ori_h, ori_w)
            resized_h, resized_w = int(ori_h * scale), int(ori_w * scale)
            image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True)
        else:
            resized_h, resized_w = ori_h, ori_w
        # pad on the left/top so height and width are multiples of 8
        if resized_h % 8 != 0 or resized_w % 8 != 0:
            image = torchvision.transforms.functional.pad(image, ((8 - resized_w % 8) % 8, (8 - resized_h % 8) % 8, 0, 0), padding_mode='reflect')
        # normalize and forwarding
        image = self.normalize(image)[None]
        predict = self.model(image)[0]
        # undo padding: the pad was added on the left/top, so keep the bottom-right crop
        predict = predict[:, -resized_h:, -resized_w:]
        # undo resize
        if resized_h != ori_h or resized_w != ori_w:
            predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True)
        
        if return_type == 'alpha':
            return predict[0]
        elif return_type == 'matting':
            predict = predict.expand(3, -1, -1)
            matting_image = input_image.clone()
            background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb
            matting_image = matting_image * predict + (1-predict) * background_rgb
            return matting_image, predict[0]
        elif return_type == 'all':
            predict = predict.expand(3, -1, -1)
            background_rgb = input_image.new_ones(input_image.shape) * background_rgb
            foreground_image = input_image * predict + (1-predict) * background_rgb
            background_image = input_image * (1-predict) + predict * background_rgb
            return foreground_image, background_image
        else:
            raise NotImplementedError(f'Unsupported return_type: {return_type}')
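

# --- Minimal usage sketch (illustration only, not part of the original module) ---
# Assumptions: the default checkpoint path from __init__ exists, and 'example.jpg'
# is a placeholder for a local RGB image. Because of the relative import above,
# run this as a module (e.g. `python -m <package>.<this_module>`) rather than as a script.
if __name__ == '__main__':
    from torchvision.io import read_image

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    engine = StyleMatteEngine(device=device)

    # read_image returns a uint8 (3, H, W) tensor; scale it to [0, 1] as forward expects
    image = read_image('example.jpg').float() / 255.0
    matting_image, alpha = engine(image, return_type='matting')
    print('composited image:', tuple(matting_image.shape), 'alpha matte:', tuple(alpha.shape))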