waleko committed
Commit 7cf938c · 1 Parent(s): 20309fe

init commit

Files changed (3)
  1. app.py +137 -0
  2. model.py +97 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,137 @@
+ import streamlit as st
+ import torch
+ import numpy as np
+ from PIL import Image
+ from model import CycleGAN, get_val_transform, de_normalize
+
+ # Configure page
+ st.set_page_config(
+     page_title="CycleGAN Image Converter",
+     page_icon="🎨",
+     layout="wide"
+ )
+
+ # Get the best available device
+ @st.cache_resource
+ def get_device():
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+         st.sidebar.success("Using GPU 🚀")
+     elif torch.backends.mps.is_available():
+         device = torch.device("mps")
+         st.sidebar.success("Using Apple Silicon 🍎")
+     else:
+         device = torch.device("cpu")
+         st.sidebar.info("Using CPU 💻")
+     return device
+
+ # Add custom CSS
+ st.markdown("""
+ <style>
+ .stApp {
+     max-width: 1200px;
+     margin: 0 auto;
+ }
+ .main {
+     padding: 2rem;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Title and description
+ st.title("CycleGAN Image Converter 🎨")
+ st.markdown("""
+ Transform images between different domains using CycleGAN.
+ Upload an image and see it converted in real-time!
+
+ *Note: Images will be resized to 256x256 pixels during conversion.*
+ """)
+
+ # Available models and their configurations
+ MODELS = [
+     {
+         "name": "Cezanne ↔ Photo",
+         "id": "cezanne2photo",
+         "model_path": "waleko/cyclegan",
+         "description": "Convert between Cezanne's painting style and photographs"
+     }
+ ]
+
+ # Sidebar controls
+ with st.sidebar:
+     st.header("Settings")
+
+     # Model selection
+     selected_model = st.selectbox(
+         "Conversion Type",
+         options=range(len(MODELS)),
+         format_func=lambda x: MODELS[x]["name"]
+     )
+
+     # Direction selection
+     direction = st.radio(
+         "Conversion Direction",
+         options=["A → B", "B → A"],
+         help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A"
+     )
+
+ # Load model (cached so the weights are downloaded only once)
+ @st.cache_resource
+ def load_model(model_path):
+     device = get_device()
+     model = CycleGAN.from_pretrained(model_path)
+     model = model.to(device)
+     model.eval()
+     return model
+
+ # Process image
+ def process_image(image, model, direction):
+     # Map the UI label to the direction key expected by model.py
+     direction_key = "a_to_b" if direction == "A → B" else "b_to_a"
+
+     # Prepare transform
+     transform = get_val_transform(model, direction_key)
+
+     # Convert PIL image to a batched tensor
+     tensor = transform(np.array(image)).unsqueeze(0)
+
+     # Move to the same device as the model
+     tensor = tensor.to(next(model.parameters()).device)
+
+     # Run the appropriate generator
+     with torch.no_grad():
+         if direction_key == "a_to_b":
+             output = model.generator_ab(tensor)
+         else:
+             output = model.generator_ba(tensor)
+
+     # Convert back to an HWC image in [0, 1]; NumPy so st.image accepts it
+     result = de_normalize(output[0], model, direction_key)
+     return result.numpy()
+
+ # Main interface
+ col1, col2 = st.columns(2)
+
+ with col1:
+     st.subheader("Input Image")
+     uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+     if uploaded_file is not None:
+         # Force RGB so PNGs with an alpha channel don't break the 3-channel transform
+         input_image = Image.open(uploaded_file).convert("RGB")
+         st.image(input_image, use_column_width=True)
+
+ with col2:
+     st.subheader("Converted Image")
+     if uploaded_file is not None:
+         try:
+             # Load and process
+             model = load_model(MODELS[selected_model]["model_path"])
+             result = process_image(input_image, model, direction)
+
+             # Display
+             st.image(result, use_column_width=True)
+         except Exception as e:
+             st.error(f"Error during conversion: {str(e)}")
+     else:
+         st.info("Upload an image to see the conversion result")
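Since app.py wires several pieces together, a quick smoke test of the same processing path is useful. This is a minimal sketch, run outside Streamlit, assuming model.py is importable; a randomly initialized CycleGAN and a random uint8 array stand in for the published weights and an upload, so it only checks shapes and value ranges, not visual quality.

```python
# Smoke test for the app's processing path (a sketch; run outside Streamlit).
import numpy as np
import torch
from model import CycleGAN, get_val_transform, de_normalize

model = CycleGAN().eval()  # randomly initialized; enough to exercise the code path
transform = get_val_transform(model, "a_to_b")

# A random uint8 HWC array stands in for an uploaded image
fake_upload = np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8)
tensor = transform(fake_upload).unsqueeze(0)   # -> [1, 3, 256, 256] after resize + crop

with torch.no_grad():
    output = model.generator_ab(tensor)        # the generator preserves spatial size

result = de_normalize(output[0], model, "a_to_b")
print(result.shape)                            # torch.Size([256, 256, 3])
print(result.min() >= 0, result.max() <= 1)    # clamped to [0, 1]
```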
model.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+ import torchvision.transforms as tr
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_features):
+         super(ResidualBlock, self).__init__()
+
+         self.block = nn.Sequential(
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(in_features, in_features, 3),
+             nn.InstanceNorm2d(in_features),
+             nn.ReLU(inplace=True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(in_features, in_features, 3),
+             nn.InstanceNorm2d(in_features),
+         )
+
+     def forward(self, x):
+         return x + self.block(x)
+
+ def generator(num_residual_blocks=9):
+     channels = 3
+     out_features = 64
+     # Initial convolution block
+     model = [
+         nn.ReflectionPad2d(channels),
+         nn.Conv2d(channels, out_features, 7),
+         nn.InstanceNorm2d(out_features),
+         nn.ReLU(inplace=True),
+     ]
+     in_features = out_features
+
+     # Downsampling
+     for _ in range(2):
+         out_features *= 2
+         model += [
+             nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+             nn.InstanceNorm2d(out_features),
+             nn.ReLU(inplace=True),
+         ]
+         in_features = out_features
+
+     # Residual blocks
+     for _ in range(num_residual_blocks):
+         model += [ResidualBlock(out_features)]
+
+     # Upsampling
+     for _ in range(2):
+         out_features //= 2
+         model += [
+             nn.Upsample(scale_factor=2),
+             nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
+             nn.InstanceNorm2d(out_features),
+             nn.ReLU(inplace=True),
+         ]
+         in_features = out_features
+
+     # Output layer
+     model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
+
+     return nn.Sequential(*model)
+
+ class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
+     def __init__(self, channel_mean_a=None, channel_std_a=None, channel_mean_b=None, channel_std_b=None):
+         super(CycleGAN, self).__init__()
+         self.generator_ab = generator()
+         self.generator_ba = generator()
+
+         # Store per-domain normalization statistics as non-trainable buffers
+         self.register_buffer('channel_mean_a', torch.tensor(channel_mean_a if channel_mean_a is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
+         self.register_buffer('channel_std_a', torch.tensor(channel_std_a if channel_std_a is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
+         self.register_buffer('channel_mean_b', torch.tensor(channel_mean_b if channel_mean_b is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
+         self.register_buffer('channel_std_b', torch.tensor(channel_std_b if channel_std_b is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
+
+ def get_val_transform(model, direction="a_to_b", size=256):
+     # Normalize with the statistics of the input domain
+     mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
+     std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b
+
+     return tr.Compose([
+         tr.ToPILImage(),
+         tr.Resize(size),
+         tr.CenterCrop(size),
+         tr.ToTensor(),
+         tr.Normalize(mean=mean.tolist(), std=std.tolist()),
+     ])
+
+ def de_normalize(tensor, model, direction="a_to_b"):
+     img_tensor = tensor.cpu().detach().clone()
+     # Move the statistics to the CPU to match img_tensor's device
+     mean = (model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b).cpu()
+     std = (model.channel_std_a if direction == "a_to_b" else model.channel_std_b).cpu()
+
+     img_tensor = img_tensor * std[:, None, None] + mean[:, None, None]
+     return torch.clamp(img_tensor.permute(1, 2, 0), 0.0, 1.0)
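Because CycleGAN inherits PyTorchModelHubMixin, the published weights can also be loaded outside the app with from_pretrained, exactly as app.py does. Below is a minimal sketch for converting a single photo; photo.jpg is a placeholder path, and the photo-to-painting direction is assumed to be B → A, matching the cezanne2photo naming in app.py.

```python
import numpy as np
import torch
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize

# Download the weights from the Hub (the repo app.py points at)
model = CycleGAN.from_pretrained("waleko/cyclegan").eval()

image = Image.open("photo.jpg").convert("RGB")  # placeholder input path
tensor = get_val_transform(model, "b_to_a")(np.array(image)).unsqueeze(0)

with torch.no_grad():
    stylized = model.generator_ba(tensor)  # photo -> painting, assuming B is the photo domain

result = de_normalize(stylized[0], model, "b_to_a")  # HWC float tensor in [0, 1]
Image.fromarray((result.numpy() * 255).astype("uint8")).save("stylized.png")
```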
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ streamlit>=1.28.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ Pillow>=10.0.0
+ numpy>=1.24.0
+ huggingface-hub>=0.23.0
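With these dependencies installed (pip install -r requirements.txt), the demo can be started locally with streamlit run app.py; a Streamlit Space picks up the same app.py entry point by default.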