import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
import matplotlib.pyplot as plt
import cv2

class _bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw // 2, fh // 2), padding_mode='zeros')
        )

    def forward(self, x):
        return self.model(x)

        # unreachable debug code kept from the original implementation
        print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
        for i, layer in enumerate(self.model):
            if i != 2:
                x = layer(x)
            else:
                x = layer(x)
                # x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
            print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
            print(x[0])
        return x

class _u_bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_u_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw // 2, fh // 2)),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        return self.model(x)


class _shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample=1):
        super(_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters or subsample != 1:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
            )

    def forward(self, x, y):
        # print(x.size(), y.size(), self.process)
        if self.process:
            y0 = self.model(x)
            # print("merge+", torch.max(y0+y), torch.min(y0+y), torch.mean(y0+y), torch.std(y0+y), y0.shape)
            return y0 + y
        else:
            # print("merge", torch.max(x+y), torch.min(x+y), torch.mean(x+y), torch.std(x+y), y.shape)
            return x + y


class _u_shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample):
        super(_u_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
                nn.Upsample(scale_factor=2, mode='nearest')
            )

    def forward(self, x, y):
        if self.process:
            return self.model(x) + y
        else:
            return x + y


class basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.residual(x1)
        return self.shortcut(x, x2)


class _u_basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        y = self.residual(self.conv1(x))
        return self.shortcut(x, y)


class _residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            init_subsample = 1
            if i == repetitions - 1 and not is_first_layer:
                init_subsample = 2
            if i == 0:
                l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            layers.append(l)
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class _upsampling_residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            if i == 0:
                l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)
            layers.append(l)
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class res_skip(nn.Module):
    def __init__(self):
        super(res_skip, self).__init__()
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)
        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)
        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)
        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)
        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)

    def forward(self, x):
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)
        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)
        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)
        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)
        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)
        x9 = self.block9(res4)
        y = self.conv15(x9)
        return y
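
# Example use of res_skip for line extraction (a minimal sketch; the checkpoint
# filename below is a placeholder assumption, not a file shipped with this module).
# The input is a single-channel image in [0, 255], padded so H and W are multiples
# of 16, matching the four stride-2 stages and four upsampling stages above.
def _example_extract_lines(image_path, ckpt_path="erika.pth", device="cpu"):
    model = res_skip()
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device).eval()
    img = np.array(Image.open(image_path).convert("L")).astype(np.float32)
    h, w = img.shape
    pad_h, pad_w = (16 - h % 16) % 16, (16 - w % 16) % 16
    img = np.pad(img, ((0, pad_h), (0, pad_w)), mode="edge")
    x = torch.from_numpy(img)[None, None].to(device)
    with torch.no_grad():
        y = model(x).clamp(0, 255)
    # crop the padding off again and return an 8-bit line map
    return y[0, 0, :h, :w].cpu().numpy().astype(np.uint8)
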
class MyDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def get_class_label(self, image_name):
        # your method here
        head, tail = os.path.split(image_name)
        # print(tail)
        return tail

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        x = Image.open(image_path)
        y = self.get_class_label(image_path.split('/')[-1])
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.image_paths)

def loadImages(folder):
    matches = []
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if os.path.isfile(file_path):
            matches.append(file_path)
    return matches
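
# Example of wiring loadImages and MyDataset into a DataLoader (a sketch; the
# folder path is a hypothetical placeholder, not a directory provided here).
def _example_build_loader(folder="data/images", batch_size=4):
    from torch.utils.data import DataLoader
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    dataset = MyDataset(loadImages(folder), transform=transform)
    # each batch yields (images, filenames); the "label" is just the file name
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
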
def crop_center_square(image):
    width, height = image.size
    side_length = min(width, height)
    left = (width - side_length) // 2
    top = (height - side_length) // 2
    right = left + side_length
    bottom = top + side_length
    cropped_image = image.crop((left, top, right, bottom))
    return cropped_image


def crop_image(image, crop_size, stride):
    width, height = image.size
    crop_width, crop_height = crop_size
    cropped_images = []
    for j in range(0, height - crop_height + 1, stride):
        for i in range(0, width - crop_width + 1, stride):
            crop_box = (i, j, i + crop_width, j + crop_height)
            cropped_image = image.crop(crop_box)
            cropped_images.append(cropped_image)
    return cropped_images

def process_image_ref(image):
    resized_image_512 = image.resize((512, 512))
    image_list = [resized_image_512]
    crop_size_384 = (384, 384)
    stride_384 = 128
    image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
    return image_list

def process_image_Q(image):
    resized_image_512 = image.resize((512, 512)).convert("RGB")
    image_list = []
    crop_size_384 = (384, 384)
    stride_384 = 128
    image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
    return image_list
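
# Sanity check for the patching scheme above (illustrative only): on the 512x512
# resize, the 384x384 crops at stride 128 start at offsets 0 and 128 per axis, so
# process_image_ref yields 5 images (full view + 4 crops) and process_image_Q yields 4.
def _example_patch_counts():
    demo = Image.new("RGB", (800, 600))
    assert len(process_image_ref(demo)) == 5
    assert len(process_image_Q(demo)) == 4
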
def process_image(image, target_width=512, target_height=512):
    img_width, img_height = image.size
    img_ratio = img_width / img_height
    target_ratio = target_width / target_height
    ratio_error = abs(img_ratio - target_ratio) / target_ratio
    if ratio_error < 0.15:
        # aspect ratio close enough to the target: resize directly
        resized_image = image.resize((target_width, target_height), Image.BICUBIC)
    else:
        if img_ratio > target_ratio:
            # wider than the target: center-crop left/right
            new_width = int(img_height * target_ratio)
            left = int((img_width - new_width) / 2)
            top = 0
            right = left + new_width
            bottom = img_height
        else:
            # taller than the target: center-crop top/bottom
            new_height = int(img_width / target_ratio)
            left = 0
            top = int((img_height - new_height) / 2)
            right = img_width
            bottom = top + new_height
        cropped_image = image.crop((left, top, right, bottom))
        resized_image = cropped_image.resize((target_width, target_height), Image.BICUBIC)
    return resized_image.convert('RGB')

def crop_image_varres(image, crop_size, h_stride, w_stride):
    width, height = image.size
    crop_width, crop_height = crop_size
    cropped_images = []
    for j in range(0, height - crop_height + 1, h_stride):
        for i in range(0, width - crop_width + 1, w_stride):
            crop_box = (i, j, i + crop_width, j + crop_height)
            cropped_images.append(image.crop(crop_box))
    return cropped_images


def process_image_ref_varres(image, target_width=512, target_height=512):
    resized_image = image.resize((target_width, target_height))
    image_list = [resized_image]
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))
    return image_list

def process_image_Q_varres(image, target_width=512, target_height=512):
    resized_image = image.resize((target_width, target_height)).convert("RGB")
    image_list = []
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))
    return image_list

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # add the (possibly projected) shortcut directly
        out = F.relu(out)
        return out

class TwoLayerResNet(nn.Module):
    # despite the name, this stacks four residual blocks
    def __init__(self, in_channels, out_channels):
        super(TwoLayerResNet, self).__init__()
        self.block1 = ResNetBlock(in_channels, out_channels)
        self.block2 = ResNetBlock(out_channels, out_channels)
        self.block3 = ResNetBlock(out_channels, out_channels)
        self.block4 = ResNetBlock(out_channels, out_channels)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x

class MultiHiddenResNetModel(nn.Module):
    def __init__(self, channels_list, num_tensors):
        super(MultiHiddenResNetModel, self).__init__()
        self.two_layer_resnets = nn.ModuleList([
            TwoLayerResNet(channels_list[idx] * 2, channels_list[min(len(channels_list) - 1, idx + 2)])
            for idx in range(num_tensors)
        ])

    def forward(self, tensor_list):
        processed_list = []
        for i, tensor in enumerate(tensor_list):
            processed_list.append(self.two_layer_resnets[i](tensor))
        return processed_list
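
# Illustration of the per-branch channel mapping above (the channel list is a
# hypothetical configuration chosen only for this example): branch idx consumes
# channels_list[idx] * 2 channels and emits channels_list[min(len - 1, idx + 2)].
def _example_multi_hidden_shapes():
    channels_list = [64, 128, 256, 512]  # assumed feature widths
    model = MultiHiddenResNetModel(channels_list, num_tensors=4)
    feats = [torch.randn(1, c * 2, 16, 16) for c in channels_list]
    outs = model(feats)
    print([o.shape[1] for o in outs])  # expected: [256, 512, 512, 512]
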
def calculate_target_size(h, w):
    # the original random branches all computed the same value, so the rule
    # collapses to: snap each side down to a multiple of 8, with a minimum of 8
    target_h = max((h // 8) * 8, 8)
    target_w = max((w // 8) * 8, 8)
    return target_h, target_w

def downsample_tensor(tensor):
    b, c, h, w = tensor.shape
    target_h, target_w = calculate_target_size(h, w)
    downsampled_tensor = F.interpolate(tensor, size=(target_h, target_w), mode='bilinear', align_corners=False)
    return downsampled_tensor
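
# Quick example of downsample_tensor (illustrative): a 1x3x37x70 tensor is resized
# to the nearest multiples of 8 at or below its spatial size, i.e. 32x64.
def _example_downsample():
    t = torch.randn(1, 3, 37, 70)
    out = downsample_tensor(t)
    print(out.shape)  # torch.Size([1, 3, 32, 64])
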
def get_pixart_config():
    pixart_config = {
        "_class_name": "Transformer2DModel",
        "_diffusers_version": "0.22.0.dev0",
        "activation_fn": "gelu-approximate",
        "attention_bias": True,
        "attention_head_dim": 72,
        "attention_type": "default",
        "caption_channels": 4096,
        "cross_attention_dim": 1152,
        "double_self_attention": False,
        "dropout": 0.0,
        "in_channels": 4,
        # "interpolation_scale": 2,
        "norm_elementwise_affine": False,
        "norm_eps": 1e-06,
        "norm_num_groups": 32,
        "norm_type": "ada_norm_single",
        "num_attention_heads": 16,
        "num_embeds_ada_norm": 1000,
        "num_layers": 28,
        "num_vector_embeds": None,
        "only_cross_attention": False,
        "out_channels": 8,
        "patch_size": 2,
        "sample_size": 128,
        "upcast_attention": False,
        # "use_additional_conditions": False,
        "use_linear_projection": False
    }
    return pixart_config
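
# A sketch of turning the config above into a diffusers Transformer2DModel,
# assuming a diffusers version whose Transformer2DModel accepts these keys
# (out_channels=8 corresponds to a PixArt-style learned-sigma output).
def _example_build_pixart_transformer():
    from diffusers import Transformer2DModel
    config = get_pixart_config()
    # from_config ignores the private "_class_name"/"_diffusers_version" entries
    model = Transformer2DModel.from_config(config)
    return model
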
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # left
        self.left_conv_1 = DoubleConv(6, 64)
        self.down_1 = nn.MaxPool2d(2, 2)
        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)
        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)
        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)
        # center
        self.center_conv = DoubleConv(512, 1024)
        # right
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)
        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)
        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)
        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)
        # output
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # left
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)
        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)
        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)
        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)
        # center
        x5 = self.center_conv(x4_down)
        # right
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)
        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)
        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)
        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)
        # output
        output = self.output(x9)
        return output
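
# Shape check for the UNet above (illustrative): it takes a 6-channel input
# (for instance two RGB images concatenated along the channel axis, which is an
# assumption about how it is fed) and returns a 3-channel map at the same
# resolution, provided H and W are divisible by 16 so the four pooling stages
# line up with the skip connections.
def _example_unet_shapes():
    net = UNet()
    x = torch.randn(1, 6, 256, 256)
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # torch.Size([1, 3, 256, 256])
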
from copy import deepcopy


def init_causal_dit(model, base_model):
    temp_ckpt = deepcopy(base_model)
    checkpoint = temp_ckpt.state_dict()
    # checkpoint['pos_embed_1d.weight'] = torch.zeros(3, model.config.num_attention_heads * model.config.attention_head_dim, device=model.pos_embed_1d.weight.device, dtype=model.pos_embed_1d.weight.dtype)
    model.load_state_dict(checkpoint, strict=True)
    del temp_ckpt
    return model

def init_controlnet(model, base_model):
    temp_ckpt = deepcopy(base_model)
    checkpoint = temp_ckpt.state_dict()
    checkpoint_weight = checkpoint['pos_embed.proj.weight']
    # widen the patch-embedding projection: copy the base model's 4 input channels
    # and leave the extra conditioning channels zero-initialized
    new_weight = torch.zeros(model.pos_embed.proj.weight.shape, device=model.pos_embed.proj.weight.device, dtype=model.pos_embed.proj.weight.dtype)
    print('model.pos_embed.proj.weight.shape', model.pos_embed.proj.weight.shape)
    new_weight[:, :4] = checkpoint_weight
    checkpoint['pos_embed.proj.weight'] = new_weight
    print('new_weight', new_weight.dtype)
    model.load_state_dict(checkpoint, strict=False)
    del temp_ckpt
    return model
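
# Minimal demonstration of the weight-widening trick used by init_controlnet
# (toy modules, purely illustrative): the base projection takes 4 input channels,
# the control projection takes 8, and only the first 4 channels are copied while
# the added ones start from zeros.
def _example_widened_patch_embed():
    class _Toy(nn.Module):
        def __init__(self, in_ch):
            super().__init__()
            self.pos_embed = nn.Module()
            self.pos_embed.proj = nn.Conv2d(in_ch, 1152, kernel_size=2, stride=2)

    base, control = _Toy(4), _Toy(8)
    control = init_controlnet(control, base)
    print(control.pos_embed.proj.weight[:, 4:].abs().sum().item())  # 0.0
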