# Spaces: Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from utils import CONFIG | |
from networks import m2ms, ops | |
import sys | |
sys.path.insert(0, './segment-anything') | |
from segment_anything import sam_model_registry | |
class sam_m2m(nn.Module):
    """Pair a SAM ViT-B segmentation model with a trainable M2M head.

    The segmentation model is always run in eval mode and under
    ``torch.no_grad()``; its outputs (features and masks) are then passed,
    together with the input image, to the M2M module selected by name from
    ``networks.m2ms``.
    """

    def __init__(self, m2m):
        """Build the M2M head and the SAM backbone.

        Args:
            m2m: Name of an M2M constructor; must be listed in
                ``m2ms.__all__``.

        Raises:
            NotImplementedError: If ``m2m`` is not listed in ``m2ms.__all__``.
        """
        super().__init__()

        if m2m not in m2ms.__all__:
            raise NotImplementedError(f"Unknown M2M {m2m}")

        # Look the constructor up by name in the m2ms module namespace.
        # NOTE(review): nc=256 presumably matches the channel count of the
        # SAM features fed to the head -- confirm against networks.m2ms.
        m2m_builder = vars(m2ms)[m2m]
        self.m2m = m2m_builder(nc=256)

        # No checkpoint is passed here, so weights are presumably loaded by
        # the caller afterwards -- TODO confirm.
        sam_builder = sam_model_registry['vit_b']
        self.seg_model = sam_builder(checkpoint=None)
        self.seg_model.eval()

    def forward(self, image, guidance):
        """Run the segmentation model without gradients, then the M2M head.

        Args:
            image: Input forwarded to ``seg_model.forward_m2m`` and to the
                M2M head.
            guidance: Guidance input forwarded to ``seg_model.forward_m2m``.

        Returns:
            Whatever the M2M head returns for ``(features, image, masks)``.
        """
        # Re-assert eval mode on every call: a parent ``.train()`` call
        # would otherwise switch the segmentation model to training mode.
        self.seg_model.eval()
        with torch.no_grad():
            seg_out = self.seg_model.forward_m2m(
                image, guidance, multimask_output=True
            )
        features, seg_masks = seg_out

        # NOTE(review): the head is invoked outside the no_grad block so it
        # can receive gradients; the source's indentation was lost, confirm.
        return self.m2m(features, image, seg_masks)

    def forward_inference(self, image_dict):
        """Inference path taking a dict input.

        Args:
            image_dict: Mapping passed to
                ``seg_model.forward_m2m_inference``; its ``"image"`` entry
                is also given to the M2M head.

        Returns:
            Tuple ``(feas, pred, post_masks)``: the features and
            post-processed masks from the segmentation model, and the M2M
            head's prediction.
        """
        self.seg_model.eval()
        with torch.no_grad():
            seg_out = self.seg_model.forward_m2m_inference(
                image_dict, multimask_output=True
            )
        features, seg_masks, post_masks = seg_out

        prediction = self.m2m(features, image_dict["image"], seg_masks)
        return features, prediction, post_masks
def get_generator_m2m(seg, m2m):
    """Build the generator network for the requested segmentation backbone.

    Args:
        seg: Name of the segmentation backbone. Only ``'sam'`` is handled.
        m2m: Name of the M2M head, forwarded to ``sam_m2m``.

    Returns:
        A ``sam_m2m`` instance when ``seg == 'sam'``.

    Raises:
        NotImplementedError: If ``seg`` is not a supported backbone name.
            Previously an unsupported name fell through to
            ``return generator`` with the local never bound, surfacing as
            an opaque ``UnboundLocalError``.
    """
    if seg == 'sam':
        return sam_m2m(m2m=m2m)

    # Fail loudly with the offending name, mirroring how sam_m2m reports an
    # unknown M2M head.
    raise NotImplementedError("Unknown segmentation model {}".format(seg))