# --------------------------------------------------------
# PersonalizeSAM -- Personalize Segment Anything Model with One Shot
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import json

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics import precision_score, recall_score
from show import *
from per_segment_anything import sam_model_registry, SamPredictor
class ImageMask(gr.components.Image):
"""
    Sets: source="upload", tool="sketch"
"""
is_template = True
def __init__(self, **kwargs):
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
def preprocess(self, x):
return super().preprocess(x)
class Mask_Weights(nn.Module):
def __init__(self):
super().__init__()
self.weights = nn.Parameter(torch.ones(2, 1, requires_grad=True) / 3)
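# Note: Mask_Weights holds two learnable scalars; the code below derives a third
# weight as 1 - weights.sum(0), so the three per-scale mask weights always sum to 1
# and start at 1/3 each. A minimal sketch of how they are combined (illustrative
# only, mirroring the weighted sum used during training):
#   weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
#   fused_logits = (logits_high * weights).sum(0)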
def point_selection(mask_sim, topk=1):
    # Top-k positive point selection: the highest-similarity locations.
    # mask_sim has shape (H, W); flat indices are converted back to (x, y)
    # coordinates as expected by SamPredictor.predict.
    h, w = mask_sim.shape
    topk_idx = mask_sim.flatten(0).topk(topk)[1]
    topk_row = (topk_idx // w).unsqueeze(0)
    topk_col = (topk_idx - topk_row * w)
    topk_xy = torch.cat((topk_col, topk_row), dim=0).permute(1, 0)
    topk_label = np.array([1] * topk)  # 1 = foreground point
    topk_xy = topk_xy.cpu().numpy()
    # Top-k negative point selection: the lowest-similarity locations.
    last_idx = mask_sim.flatten(0).topk(topk, largest=False)[1]
    last_row = (last_idx // w).unsqueeze(0)
    last_col = (last_idx - last_row * w)
    last_xy = torch.cat((last_col, last_row), dim=0).permute(1, 0)
    last_label = np.array([0] * topk)  # 0 = background point
    last_xy = last_xy.cpu().numpy()
    return topk_xy, topk_label, last_xy, last_label
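# Illustrative shapes (a sketch, not executed by the app): with topk=1 the function
# returns a (1, 2) array of (x, y) pixel coordinates with label 1 for the most
# similar location, plus a (1, 2) array with label 0 for the least similar one,
# matching the point prompt format of SamPredictor.predict.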
def calculate_dice_loss(inputs, targets, num_masks = 1):
"""
Compute the DICE loss, similar to generalized IOU for masks
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
"""
inputs = inputs.sigmoid()
inputs = inputs.flatten(1)
numerator = 2 * (inputs * targets).sum(-1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_masks
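# Quick sanity check (hypothetical, not run by the app): a strongly positive logit
# on every target pixel drives the dice loss toward 0, e.g.
#   >>> calculate_dice_loss(torch.full((1, 16), 10.0), torch.ones(1, 16))
#   tensor(...)  # close to 0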
def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default = 0.25.
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
return loss.mean(1).sum() / num_masks
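# Quick sanity check (hypothetical, not run by the app): confident, correct logits
# yield a near-zero focal loss because the modulating factor (1 - p_t) ** gamma
# down-weights easy examples, e.g.
#   >>> calculate_sigmoid_focal_loss(torch.full((1, 16), 10.0), torch.ones(1, 16))
#   tensor(...)  # close to 0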
def inference(ic_image, ic_mask, image1, image2):
# in context image and mask
ic_image = np.array(ic_image.convert("RGB"))
ic_mask = np.array(ic_mask.convert("RGB"))
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
# sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
predictor = SamPredictor(sam)
# Image features encoding
ref_mask = predictor.set_image(ic_image, ic_mask)
ref_feat = predictor.features.squeeze().permute(1, 2, 0)
ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
ref_mask = ref_mask.squeeze()[0]
# Target feature extraction
print("======> Obtain Location Prior" )
target_feat = ref_feat[ref_mask > 0]
target_embedding = target_feat.mean(0).unsqueeze(0)
target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
target_embedding = target_embedding.unsqueeze(0)
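    # At this point target_feat is an L2-normalized (1, C) prototype of the masked
    # region, used below for cosine similarity against the test image features,
    # while target_embedding keeps an extra batch dimension for the decoder's
    # target-semantic prompting.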
output_image = []
for test_image in [image1, image2]:
print("======> Testing Image" )
test_image = np.array(test_image.convert("RGB"))
# Image feature encoding
predictor.set_image(test_image)
test_feat = predictor.features.squeeze()
# Cosine similarity
C, h, w = test_feat.shape
test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
test_feat = test_feat.reshape(C, h * w)
sim = target_feat @ test_feat
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
sim,
input_size=predictor.input_size,
original_size=predictor.original_size).squeeze()
# Positive-negative location prior
topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)
# Obtain the target guidance for cross-attention layers
sim = (sim - sim.mean()) / torch.std(sim)
sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)
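        # attn_sim is the standardized similarity map, resized to the 64x64 token
        # grid, passed through a sigmoid and flattened so it can bias the mask
        # decoder's cross-attention toward the target (shape roughly (1, 1, 1, 4096)).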
# First-step prediction
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
multimask_output=False,
attn_sim=attn_sim, # Target-guided Attention
target_embedding=target_embedding # Target-semantic Prompting
)
best_idx = 0
# Cascaded Post-refinement-1
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
# Cascaded Post-refinement-2
y, x = np.nonzero(masks[best_idx])
x_min = x.min()
x_max = x.max()
y_min = y.min()
y_max = y.max()
input_box = np.array([x_min, y_min, x_max, y_max])
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
final_mask = masks[best_idx]
mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
mask_colors[final_mask, :] = np.array([[128, 0, 0]])
output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
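# Minimal usage sketch for inference() (assumes the example files shipped with this
# Space exist; not executed at import time):
#   ref_img = Image.open("./examples/cat_00.jpg")
#   ref_mask = Image.open("./examples/cat_00.png")
#   out1, out2 = inference(ref_img, ref_mask,
#                          Image.open("./examples/cat_01.jpg"),
#                          Image.open("./examples/cat_02.jpg"))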
def inference_scribble(image, image1, image2):
# in context image and mask
ic_image = image["image"]
ic_mask = image["mask"]
ic_image = np.array(ic_image.convert("RGB"))
ic_mask = np.array(ic_mask.convert("RGB"))
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
# sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
predictor = SamPredictor(sam)
# Image features encoding
ref_mask = predictor.set_image(ic_image, ic_mask)
ref_feat = predictor.features.squeeze().permute(1, 2, 0)
ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
ref_mask = ref_mask.squeeze()[0]
# Target feature extraction
print("======> Obtain Location Prior" )
target_feat = ref_feat[ref_mask > 0]
target_embedding = target_feat.mean(0).unsqueeze(0)
target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
target_embedding = target_embedding.unsqueeze(0)
output_image = []
for test_image in [image1, image2]:
print("======> Testing Image" )
test_image = np.array(test_image.convert("RGB"))
# Image feature encoding
predictor.set_image(test_image)
test_feat = predictor.features.squeeze()
# Cosine similarity
C, h, w = test_feat.shape
test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
test_feat = test_feat.reshape(C, h * w)
sim = target_feat @ test_feat
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
sim,
input_size=predictor.input_size,
original_size=predictor.original_size).squeeze()
# Positive-negative location prior
topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)
# Obtain the target guidance for cross-attention layers
sim = (sim - sim.mean()) / torch.std(sim)
sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)
# First-step prediction
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
multimask_output=False,
attn_sim=attn_sim, # Target-guided Attention
target_embedding=target_embedding # Target-semantic Prompting
)
best_idx = 0
# Cascaded Post-refinement-1
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
# Cascaded Post-refinement-2
y, x = np.nonzero(masks[best_idx])
x_min = x.min()
x_max = x.max()
y_min = y.min()
y_max = y.max()
input_box = np.array([x_min, y_min, x_max, y_max])
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
final_mask = masks[best_idx]
mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
mask_colors[final_mask, :] = np.array([[128, 0, 0]])
output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
def inference_finetune_train(ic_image, ic_mask, image1, image2):
# in context image and mask
ic_image = np.array(ic_image.convert("RGB"))
ic_mask = np.array(ic_mask.convert("RGB"))
gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to('cpu')
# gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
# sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
for name, param in sam.named_parameters():
param.requires_grad = False
predictor = SamPredictor(sam)
    # Obtain the self location prior
print("======> Obtain Self Location Prior" )
# Image features encoding
ref_mask = predictor.set_image(ic_image, ic_mask)
ref_feat = predictor.features.squeeze().permute(1, 2, 0)
ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
ref_mask = ref_mask.squeeze()[0]
# Target feature extraction
target_feat = ref_feat[ref_mask > 0]
target_feat_mean = target_feat.mean(0)
target_feat_max = torch.max(target_feat, dim=0)[0]
target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)
# Cosine similarity
h, w, C = ref_feat.shape
target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
sim = target_feat @ ref_feat
    # Save target_feat so it can be reused at test time
    torch.save(target_feat, 'target_feat.pth')
    print("target_feat saved to 'target_feat.pth'.")
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
sim,
input_size=predictor.input_size,
original_size=predictor.original_size).squeeze()
# Positive location prior
topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
print('======> Start Training')
# Learnable mask weights
mask_weights = Mask_Weights().to('cpu')
# mask_weights = Mask_Weights()
mask_weights.train()
train_epoch = 1000
optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-4, eps=1e-4, betas=(0.9, 0.999), weight_decay=0.01, amsgrad=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epoch)
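    # Only the two scalars in Mask_Weights are optimized here; every SAM parameter
    # was frozen above (requires_grad = False), so each step just re-runs the mask
    # decoder and learns how to blend its three output scales against gt_mask.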
for train_idx in range(train_epoch):
# Run the decoder
masks, scores, logits, logits_high = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
multimask_output=True)
logits_high = logits_high.flatten(1)
# Weighted sum three-scale masks
weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
logits_high = logits_high * weights
logits_high = logits_high.sum(0).unsqueeze(0)
dice_loss = calculate_dice_loss(logits_high, gt_mask)
focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask)
loss = dice_loss + focal_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
if train_idx % 10 == 0:
print('Train Epoch: {:} / {:}'.format(train_idx, train_epoch))
current_lr = scheduler.get_last_lr()[0]
print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))
mask_weights.eval()
weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
weights_np = weights.detach().cpu().numpy()
print('======> Mask weights:\n', weights_np)
    # 1. Save the learned mask weights
    torch.save(mask_weights.state_dict(), 'mask_weights.pth')
    print("Mask weights saved to 'mask_weights.pth'.")
    ######################### End of training ########################################
    # 2. Test-only code
    # Re-initialize the module and load the saved weights
    mask_weights = Mask_Weights().to('cpu')
    mask_weights.load_state_dict(torch.load('mask_weights.pth'))
    mask_weights.eval()  # evaluation mode (no further training)
weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
weights_np = weights.detach().cpu().numpy()
print('======> Mask weights:\n', weights_np)
print('======> Start Testing')
output_image = []
for test_image in [image1, image2]:
test_image = np.array(test_image.convert("RGB"))
        # Image feature encoding
        predictor.set_image(test_image)
        test_feat = predictor.features.squeeze()
# Cosine similarity
C, h, w = test_feat.shape
test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
test_feat = test_feat.reshape(C, h * w)
sim = target_feat @ test_feat
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
sim,
input_size=predictor.input_size,
original_size=predictor.original_size).squeeze()
        # Positive location prior
        topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
        print("Prompt point coordinates:", topk_xy)
# First-step prediction
masks, scores, logits, logits_high = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
multimask_output=True)
        # Print prediction scores
        # print("Prediction scores:")
        # for idx, score in enumerate(scores):
        #     print(f"Mask {idx + 1}: {score.item():.4f}")
        # Weighted sum of the three-scale masks
logits_high = logits_high * weights.unsqueeze(-1)
logit_high = logits_high.sum(0)
mask = (logit_high > 0).detach().cpu().numpy()
logits = logits * weights_np[..., None]
logit = logits.sum(0)
        # Cascaded Post-refinement-1
y, x = np.nonzero(mask)
x_min = x.min()
x_max = x.max()
y_min = y.min()
y_max = y.max()
input_box = np.array([x_min, y_min, x_max, y_max])
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
mask_input=logit[None, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
        # Cascaded Post-refinement-2
y, x = np.nonzero(masks[best_idx])
x_min = x.min()
x_max = x.max()
y_min = y.min()
y_max = y.max()
input_box = np.array([x_min, y_min, x_max, y_max])
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
final_mask = masks[best_idx]
        # Print prediction scores
        print("Prediction scores:")
        for idx, score in enumerate(scores):
            print(f"Mask {idx + 1}: {score.item():.4f}")
        # Extract the coordinates of the final mask
        # y_coords, x_coords = np.nonzero(final_mask)
        # # Zip the coordinates into (y, x) pairs
        # coordinates = list(zip(y_coords, x_coords))
        # # Print the coordinates
        # print("Segmented coordinates:")
        # for coord in coordinates:
        #     print(coord)
        # Create the output image and draw the scores
output_img = Image.fromarray((test_image).astype('uint8'), 'RGB')
draw = ImageDraw.Draw(output_img)
        # Draw the confidence score over each mask region
        for idx, (mask, score) in enumerate(zip(masks, scores)):
            y, x = np.nonzero(mask)
            if len(x) > 0 and len(y) > 0:  # draw text only when the mask is non-empty
                x_center = int(x.mean())
                y_center = int(y.mean())
                draw.text((x_center, y_center), f"{score.item():.2f}", fill=(255, 255, 0))
        # Append the image with the final mask and scores to the output list
mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
mask_colors[final_mask, :] = np.array([[128, 0, 0]])
overlay_image = Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB')
draw_overlay = ImageDraw.Draw(overlay_image)
for idx, score in enumerate(scores):
draw_overlay.text((10, 10 + 20 * idx), f"Mask {idx + 1}: {score.item():.2f}", fill=(255, 255, 0))
output_image.append(overlay_image)
# output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
return output_image[0].resize((224, 224)), output_image[1].resize((224, 224))
# Draw contours and bounding boxes around the regions of a binary mask
def draw_contours_and_bboxes(image, mask):
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Count the detected objects
    object_count = len(contours)
    # Draw the contours and bounding boxes on the image
    for contour in contours:
        # Bounding box
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)  # green bounding box
        # Contour outline
        cv2.drawContours(image, [contour], -1, (0, 0, 255), 2)  # contour outline in (0, 0, 255)
    return image, object_count
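# Illustrative call (a sketch under the assumption that `final_mask` is a boolean
# HxW array and `rgb` is the matching HxWx3 uint8 image):
#   annotated, count = draw_contours_and_bboxes(rgb.copy(), final_mask.astype(np.uint8))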
def inference_finetune_test(image1, image2, image3, image4):
# in context image and mask
# ic_image = np.array(ic_image.convert("RGB"))
# ic_mask = np.array(ic_mask.convert("RGB"))
# gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
# gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to('cpu')
# # gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to('cpu')
# # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
# for name, param in sam.named_parameters():
# param.requires_grad = False
predictor = SamPredictor(sam)
    # # Obtain the self location prior
print("======> Obtain Self Location Prior" )
# Image features encoding
# ref_mask = predictor.set_image(ic_image, ic_mask)
# ref_feat = predictor.features.squeeze().permute(1, 2, 0)
# ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
# ref_mask = ref_mask.squeeze()[0]
# # Target feature extraction
# target_feat = ref_feat[ref_mask > 0]
# target_feat_mean = target_feat.mean(0)
# target_feat_max = torch.max(target_feat, dim=0)[0]
# target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)
# # Cosine similarity
# h, w, C = ref_feat.shape
# target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
# ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
# ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
# sim = target_feat @ ref_feat
# sim = sim.reshape(1, 1, h, w)
# sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
# sim = predictor.model.postprocess_masks(
# sim,
# input_size=predictor.input_size,
# original_size=predictor.original_size).squeeze()
# # Positive location prior
# topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
# print('======> Start Training')
# # Learnable mask weights
# mask_weights = Mask_Weights().to('cpu')
# # mask_weights = Mask_Weights()
# mask_weights.train()
# train_epoch = 1000
# optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-4, eps=1e-4, betas=(0.9, 0.999), weight_decay=0.01, amsgrad=False)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epoch)
# for train_idx in range(train_epoch):
# # Run the decoder
# masks, scores, logits, logits_high = predictor.predict(
# point_coords=topk_xy,
# point_labels=topk_label,
# multimask_output=True)
# logits_high = logits_high.flatten(1)
# # Weighted sum three-scale masks
# weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
# logits_high = logits_high * weights
# logits_high = logits_high.sum(0).unsqueeze(0)
# dice_loss = calculate_dice_loss(logits_high, gt_mask)
# focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask)
# loss = dice_loss + focal_loss
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# scheduler.step()
# if train_idx % 10 == 0:
# print('Train Epoch: {:} / {:}'.format(train_idx, train_epoch))
# current_lr = scheduler.get_last_lr()[0]
# print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))
# mask_weights.eval()
# weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
# weights_np = weights.detach().cpu().numpy()
# print('======> Mask weights:\n', weights_np)
    # # 1. Save the learned mask weights
    # torch.save(mask_weights.state_dict(), 'mask_weights.pth')
    # print("Mask weights saved to 'mask_weights.pth'.")
    ######################### End of training ########################################
    # 2. Test-only code
    # Initialize the module and load the weights saved by the training tab
    mask_weights = Mask_Weights().to('cpu')
    mask_weights.load_state_dict(torch.load('mask_weights.pth'))
    mask_weights.eval()  # evaluation mode (no further training)
weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
weights_np = weights.detach().cpu().numpy()
print('======> Mask weights:\n', weights_np)
print('======> Start Testing')
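    # This tab assumes 'mask_weights.pth' and 'target_feat.pth' were already written
    # by the training tab; without them the torch.load calls above and below will fail.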
output_image = []
    # List collecting the per-image SAM segmentation results
    segmentation_results = []
    for img_idx, test_image in enumerate([image1, image2, image3, image4]):
test_image = np.array(test_image.convert("RGB"))
        # Image feature encoding
        predictor.set_image(test_image)
        test_feat = predictor.features.squeeze()
# Cosine similarity
C, h, w = test_feat.shape
test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
test_feat = test_feat.reshape(C, h * w)
        # Load the target_feat saved by the training tab
        target_feat = torch.load('target_feat.pth')
sim = target_feat @ test_feat
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
sim,
input_size=predictor.input_size,
original_size=predictor.original_size).squeeze()
        # Positive location prior
        topk_xy, topk_label, _, _ = point_selection(sim, topk=1)
        print("Prompt point coordinates:", topk_xy)
# First-step prediction
masks, scores, logits, logits_high = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
multimask_output=True)
        # Print prediction scores
        # print("Prediction scores:")
        # for idx, score in enumerate(scores):
        #     print(f"Mask {idx + 1}: {score.item():.4f}")
        # Weighted sum of the three-scale masks
logits_high = logits_high * weights.unsqueeze(-1)
logit_high = logits_high.sum(0)
mask = (logit_high > 0).detach().cpu().numpy()
logits = logits * weights_np[..., None]
logit = logits.sum(0)
        # Cascaded Post-refinement-1
y, x = np.nonzero(mask)
x_min = x.min()
x_max = x.max()
y_min = y.min()
y_max = y.max()
input_box = np.array([x_min, y_min, x_max, y_max])
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
mask_input=logit[None, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
        # Cascaded Post-refinement-2
y, x = np.nonzero(masks[best_idx])
x_min = x.min()
x_max = x.max()
y_min = y.min()
y_max = y.max()
input_box = np.array([x_min, y_min, x_max, y_max])
masks, scores, logits, _ = predictor.predict(
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)
final_mask = masks[best_idx]
        # Dictionary holding this image's results for JSON export
        result = {
            "image": f"image_{img_idx}",  # unique name identifying the image
"masks": [],
"scores": [],
"coordinates": []
}
for idx, (mask, score) in enumerate(zip(masks, scores)):
            mask_coords = np.array(np.nonzero(mask)).T.tolist()  # extract mask coordinates as (y, x) pairs
result["masks"].append(mask_coords)
result["scores"].append(score.item())
            # Add the coordinate information for each mask
result["coordinates"].append(mask_coords)
            # Compute the center coordinate of each mask
            if mask_coords:  # only when the mask has coordinates
y_coords, x_coords = zip(*mask_coords)
center_y = int(np.mean(y_coords))
center_x = int(np.mean(x_coords))
                # Mark the center coordinate on the image
output_img = Image.fromarray((test_image).astype('uint8'), 'RGB')
draw = ImageDraw.Draw(output_img)
draw.text((center_x, center_y), f"({center_x}, {center_y})", fill=(255, 0, 0))
                # Append the annotated image
output_image.append(output_img)
segmentation_results.append(result)
        # Save to a JSON file
with open("segmentation_results.json", "w") as f:
json.dump(segmentation_results, f, indent=4)
print("Segmentation results saved as 'segmentation_results.json'")
        # Print prediction scores
        print("Prediction scores:")
for idx, score in enumerate(scores):
print(f"Mask {idx + 1}: {score.item():.4f}")
        # Extract the coordinates of the final mask
        # y_coords, x_coords = np.nonzero(final_mask)
        # # Zip the coordinates into (y, x) pairs
        # coordinates = list(zip(y_coords, x_coords))
        # # Print the coordinates
        # print("Segmented coordinates:")
        # for coord in coordinates:
        #     print(coord)
        # Create the output image and draw the scores
output_img = Image.fromarray((test_image).astype('uint8'), 'RGB')
draw = ImageDraw.Draw(output_img)
        # Count the segmented objects
        segmented_count = sum((mask.sum() > 0) for mask in masks)  # a positive pixel sum counts as a valid segmentation
        # draw.text((170, 10), f"Cnt: {segmented_count}", fill=(255, 0, 0))  # show the segmentation count
        # Draw the confidence score over each mask region
for idx, (mask, score) in enumerate(zip(masks, scores)):
y, x = np.nonzero(mask)
            if len(x) > 0 and len(y) > 0:  # draw text only when the mask is non-empty
x_center = int(x.mean())
y_center = int(y.mean())
# draw.text((x_center, y_center), f"{score.item():.2f}", fill=(255, 255, 0))
        # Append the image with the final mask and scores to the output list
mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        # Draw contours and bounding boxes for the segmented (red mask) region
        test_image_np = np.array(test_image)
        # Convert 'final_mask' to an 8-bit mask for OpenCV
        final_mask_obj = final_mask.astype(np.uint8)
        # Draw contours and bounding boxes over the mask region
        overlay_image, object_count = draw_contours_and_bboxes(test_image_np.copy(), final_mask_obj)
        # Print the object count
        print(f"Detected {object_count} segmented objects.")
        # Final image with score annotations
        overlay_image = Image.fromarray(overlay_image)
        # Show the segmented object count (near the top of the image)
draw_overlay = ImageDraw.Draw(overlay_image)
draw_overlay.text((170, 10), f"Cnt: {segmented_count}", fill=(255, 255, 0))
for idx, score in enumerate(scores):
draw_overlay.text((10, 10 + 20 * idx), f"Mask {idx + 1}: {score.item():.2f}", fill=(255, 255, 0))
output_image.append(overlay_image)
# overlay_image = Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB')
# draw_overlay = ImageDraw.Draw(overlay_image)
        # # Show the segmented object count again (near the top of the image)
# draw_overlay.text((170, 10), f"Cnt: {segmented_count}", fill=(255, 255, 0))
# for idx, score in enumerate(scores):
# draw_overlay.text((10, 10 + 20 * idx), f"Mask {idx + 1}: {score.item():.2f}", fill=(255, 255, 0))
# output_image.append(overlay_image)
# output_image.append(Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'))
return output_image[0].resize((224, 224)), output_image[1].resize((224, 224)), output_image[2].resize((224, 224)), output_image[3].resize((224, 224))
description = """
<div style="text-align: center; font-weight: bold;">
<span style="font-size: 18px" id="paper-info">
[<a href="https://github.com/ZrrSkywalker/Personalize-SAM" target="_blank"><font color='black'>Github</font></a>]
[<a href="https://arxiv.org/pdf/2305.03048.pdf" target="_blank"><font color='black'>Paper</font></a>]
</span>
</div>
"""
main = gr.Interface(
fn=inference,
inputs=[
gr.Image(type="pil", label="in context image",),
gr.Image(type="pil", label="in context mask"),
gr.Image(type="pil", label="test image1"),
gr.Image(type="pil", label="test image2"),
],
outputs=[
gr.Image(type="pil", label="output image1"),
gr.Image(type="pil", label="output image2"),
],
allow_flagging="never",
title="Personalize Segment Anything Model with 1 Shot",
description=description,
examples=[
["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
]
)
main_scribble = gr.Interface(
fn=inference_scribble,
inputs=[
gr.ImageMask(label="[Stroke] Draw on Image", type="pil"),
gr.Image(type="pil", label="test image1"),
gr.Image(type="pil", label="test image2"),
],
outputs=[
gr.Image(type="pil", label="output image1"),
gr.Image(type="pil", label="output image2"),
],
allow_flagging="never",
title="Personalize Segment Anything Model with 1 Shot",
description=description,
examples=[
["./examples/cat_00.jpg", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
["./examples/duck_toy_00.jpg", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
]
)
main_finetune_train = gr.Interface(
fn=inference_finetune_train,
inputs=[
gr.Image(type="pil", label="in context image"),
gr.Image(type="pil", label="in context mask"),
gr.Image(type="pil", label="test image1"),
gr.Image(type="pil", label="test image2"),
],
outputs=[
gr.components.Image(type="pil", label="output image1"),
gr.components.Image(type="pil", label="output image2"),
],
allow_flagging="never",
title="Personalize Segment Anything Model with 1 Shot Train",
description=description,
examples=[
["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
]
)
main_finetune_test = gr.Interface(
fn=inference_finetune_test,
inputs=[
gr.Image(type="pil", label="test image1"),
gr.Image(type="pil", label="test image2"),
gr.Image(type="pil", label="test image3"),
gr.Image(type="pil", label="test image4"),
],
outputs=[
gr.components.Image(type="pil", label="output image1"),
gr.components.Image(type="pil", label="output image2"),
gr.components.Image(type="pil", label="output image3"),
gr.components.Image(type="pil", label="output image4"),
],
allow_flagging="never",
title="Personalize Segment Anything Model with 1 Shot Test",
description=description,
examples=[
["./examples/cat_00.jpg", "./examples/cat_00.png", "./examples/cat_01.jpg", "./examples/cat_02.jpg"],
["./examples/colorful_sneaker_00.jpg", "./examples/colorful_sneaker_00.png", "./examples/colorful_sneaker_01.jpg", "./examples/colorful_sneaker_02.jpg"],
["./examples/duck_toy_00.jpg", "./examples/duck_toy_00.png", "./examples/duck_toy_01.jpg", "./examples/duck_toy_02.jpg"],
]
)
demo = gr.Blocks()
with demo:
gr.TabbedInterface(
[main_finetune_train, main_finetune_test],
["Personalize-SAM-F_train", "Personalize-SAM-F_test"],
)
demo.launch(share=True)