Dattatreya commited on
Commit
67953fe
·
verified ·
1 Parent(s): 8a54ae0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import numpy as np
5
+ import gradio as gr
6
+ from torch import optim
7
+ import torchvision
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ def create_vgg_model():
12
+ model_weights = torchvision.models.VGG19_Weights.DEFAULT
13
+ model = torchvision.models.vgg19(weights=model_weights)
14
+ for param in model.parameters():
15
+ param.requires_grad = False
16
+ model = model.features
17
+ return model
18
+
19
+ def preprocess(img):
20
+ image = Image.fromarray(img).convert('RGB')
21
+ imsize = 196
22
+ transform = transforms.Compose([
23
+ transforms.Resize((imsize, imsize)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
+ ])
27
+ image = transform(image)
28
+ image = image.unsqueeze(dim=0)
29
+ return image
30
+
31
+ def deprocess(image):
32
+ image = image.clone()
33
+ image = image.squeeze(0)
34
+ image = image.permute(1, 2, 0)
35
+ image = image.cpu().detach().numpy()
36
+ image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
37
+ image = image.clip(0, 1)
38
+ return image
39
+
40
+ def get_features(image, model):
41
+ features = {}
42
+ layers = {
43
+ '0': 'layer_1',
44
+ '5': 'layer_2',
45
+ '10': 'layer_3',
46
+ '19': 'layer_4',
47
+ '28': 'layer_5'
48
+ }
49
+ x = image
50
+ for name, layer in model._modules.items():
51
+ x = layer(x)
52
+ if name in layers:
53
+ features[layers[name]] = x
54
+ return features
55
+
56
+ def gram_matrix(image):
57
+ b, c, h, w = image.size()
58
+ image = image.view(c, h * w)
59
+ gram = torch.mm(image, image.t())
60
+ return gram
61
+
62
+ def content_loss(target, content):
63
+ return torch.mean((target - content) ** 2)
64
+
65
+ def style_loss(target_features, style_grams):
66
+ loss = 0
67
+ for layer in target_features:
68
+ target_gram = gram_matrix(target_features[layer])
69
+ style_gram = style_grams[layer]
70
+ layer_style_loss = torch.mean((target_gram - style_gram) ** 2)
71
+ loss += layer_style_loss
72
+ return loss
73
+
74
+ def total_loss(content_loss, style_loss, alpha, beta):
75
+ return alpha * content_loss + beta * style_loss
76
+
77
+ def predict(content_image, style_image):
78
+ model = create_vgg_model().to(device).eval()
79
+ content_img = preprocess(content_image).to(device)
80
+ style_img = preprocess(style_image).to(device)
81
+ target_img = content_img.clone().requires_grad_(True)
82
+ content_features = get_features(content_img, model)
83
+ style_features = get_features(style_img, model)
84
+ style_gram = {layer: gram_matrix(style_features[layer]) for layer in style_features}
85
+ optimizer = optim.Adam([target_img], lr=0.06)
86
+ alpha_param = 1
87
+ beta_param = 1e2
88
+ epochs = 60
89
+ for i in range(epochs):
90
+ target_features = get_features(target_img, model)
91
+ c_loss = content_loss(target_features['layer_4'], content_features['layer_4'])
92
+ s_loss = style_loss(target_features, style_gram)
93
+ t_loss = total_loss(c_loss, s_loss, alpha_param, beta_param)
94
+ optimizer.zero_grad()
95
+ t_loss.backward()
96
+ optimizer.step()
97
+ results = deprocess(target_img)
98
+ return Image.fromarray((results * 255).astype(np.uint8))
99
+
100
+ title = "Neural Style Transfer 🎨"
101
+
102
+ demo = gr.Interface(fn=predict,
103
+ inputs=['image', 'image'],
104
+ outputs=gr.Image(),
105
+ title=title)
106
+
107
+ demo.launch(debug=False, share=False)
108
+