init commit
- app.py +133 -0
- model.py +95 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,133 @@
import streamlit as st
import torch
import numpy as np
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize

# Configure page
st.set_page_config(
    page_title="CycleGAN Image Converter",
    page_icon="🎨",
    layout="wide"
)

# Get the best available device
@st.cache_resource
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.sidebar.success("Using GPU 🚀")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        st.sidebar.success("Using Apple Silicon 🍎")
    else:
        device = torch.device("cpu")
        st.sidebar.info("Using CPU 💻")
    return device

# Add custom CSS
st.markdown("""
    <style>
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    .main {
        padding: 2rem;
    }
    </style>
""", unsafe_allow_html=True)

# Title and description
st.title("CycleGAN Image Converter 🎨")
st.markdown("""
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!

*Note: Images are resized and center-cropped to 256x256 pixels during conversion.*
""")

# Available models and their configurations
MODELS = [
    {
        "name": "Cezanne ↔ Photo",
        "id": "cezanne2photo",
        "model_path": "waleko/cyclegan",
        "description": "Convert between Cezanne's painting style and photographs"
    }
]

# Map the UI labels to the direction keys expected by model.py's helpers
DIRECTION_KEYS = {"A → B": "a_to_b", "B → A": "b_to_a"}

# Sidebar controls
with st.sidebar:
    st.header("Settings")

    # Model selection
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"]
    )

    # Direction selection
    direction = st.radio(
        "Conversion Direction",
        options=["A → B", "B → A"],
        help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A"
    )

# Load model
@st.cache_resource
def load_model(model_path):
    device = get_device()
    model = CycleGAN.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model

# Process image
def process_image(image, model, direction):
    direction_key = DIRECTION_KEYS[direction]

    # Prepare transform (normalizes with the source domain's statistics)
    transform = get_val_transform(model, direction_key)

    # Convert PIL image to tensor and add a batch dimension
    tensor = transform(np.array(image)).unsqueeze(0)

    # Move to the same device as the model
    tensor = tensor.to(next(model.parameters()).device)

    # Run the generator for the chosen direction
    with torch.no_grad():
        if direction_key == "a_to_b":
            output = model.generator_ab(tensor)
        else:
            output = model.generator_ba(tensor)

    # Convert back to an HWC image in [0, 1]
    result = de_normalize(output[0], model, direction_key)
    return result

# Main interface
col1, col2 = st.columns(2)

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Force RGB so the 3-channel normalization also works for RGBA/grayscale files
        input_image = Image.open(uploaded_file).convert("RGB")
        st.image(input_image, use_column_width=True)

with col2:
    st.subheader("Converted Image")
    if uploaded_file is not None:
        try:
            # Load and process
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)

            # Display (st.image expects a NumPy array, not a torch tensor)
            st.image(result.numpy(), use_column_width=True)
        except Exception as e:
            st.error(f"Error during conversion: {str(e)}")
    else:
        st.info("Upload an image to see the conversion result")
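For reference, the same conversion pipeline can be exercised without the Streamlit UI. This is a minimal sketch built only from the helpers in model.py; the checkpoint id comes from MODELS above, while "input.jpg" and "output.jpg" are placeholder paths:

import numpy as np
import torch
from PIL import Image

from model import CycleGAN, get_val_transform, de_normalize

# Load the pretrained weights from the Hub (repo id taken from MODELS above)
model = CycleGAN.from_pretrained("waleko/cyclegan")
model.eval()

# "input.jpg" is a placeholder path for any RGB test image
image = Image.open("input.jpg").convert("RGB")

# Normalize with the source-domain statistics and add a batch dimension
tensor = get_val_transform(model, "a_to_b")(np.array(image)).unsqueeze(0)

with torch.no_grad():
    fake_b = model.generator_ab(tensor)

# de_normalize returns an HWC float tensor in [0, 1]
result = de_normalize(fake_b[0], model, "a_to_b")
Image.fromarray((result.numpy() * 255).astype("uint8")).save("output.jpg")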
model.py
ADDED
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import torchvision.transforms as tr

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        # Residual (skip) connection around the two conv layers
        return x + self.block(x)

def generator(num_residual_blocks=9):
    channels = 3
    out_features = 64
    # Initial block: reflection-pad by 3 so the 7x7 conv preserves spatial size
    model = [
        nn.ReflectionPad2d(3),
        nn.Conv2d(channels, out_features, 7),
        nn.InstanceNorm2d(out_features),
        nn.ReLU(inplace=True),
    ]
    in_features = out_features

    # Downsampling: two stride-2 convs, 256 -> 128 -> 64
    for _ in range(2):
        out_features *= 2
        model += [
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Residual blocks at the bottleneck resolution
    for _ in range(num_residual_blocks):
        model += [ResidualBlock(out_features)]

    # Upsampling: nearest-neighbor upsample + conv, 64 -> 128 -> 256
    for _ in range(2):
        out_features //= 2
        model += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

    # Output layer: back to 3 channels, Tanh keeps values in [-1, 1]
    model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

    return nn.Sequential(*model)

class CycleGAN(nn.Module, PyTorchModelHubMixin, pipeline_tag="image-to-image"):
    def __init__(self, channel_mean_a=None, channel_std_a=None, channel_mean_b=None, channel_std_b=None):
        super(CycleGAN, self).__init__()
        self.generator_ab = generator()
        self.generator_ba = generator()

        # Store per-domain normalization statistics as non-trainable buffers
        # (defaulting to 0.5 so the model works without explicit statistics)
        self.register_buffer('channel_mean_a', torch.tensor(channel_mean_a if channel_mean_a is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
        self.register_buffer('channel_std_a', torch.tensor(channel_std_a if channel_std_a is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
        self.register_buffer('channel_mean_b', torch.tensor(channel_mean_b if channel_mean_b is not None else [0.5, 0.5, 0.5], dtype=torch.float32))
        self.register_buffer('channel_std_b', torch.tensor(channel_std_b if channel_std_b is not None else [0.5, 0.5, 0.5], dtype=torch.float32))

def get_val_transform(model, direction="a_to_b", size=256):
    # Inputs are normalized with the *source* domain's statistics
    mean = model.channel_mean_a if direction == "a_to_b" else model.channel_mean_b
    std = model.channel_std_a if direction == "a_to_b" else model.channel_std_b

    return tr.Compose([
        tr.ToPILImage(),
        tr.Resize(size),
        tr.CenterCrop(size),
        tr.ToTensor(),
        tr.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])

def de_normalize(tensor, model, direction="a_to_b"):
    img_tensor = tensor.cpu().detach().clone()
    # The converted image lives in the *target* domain, so undo that domain's
    # normalization; .cpu() keeps the statistics on the same device as img_tensor
    mean = (model.channel_mean_b if direction == "a_to_b" else model.channel_mean_a).cpu()
    std = (model.channel_std_b if direction == "a_to_b" else model.channel_std_a).cpu()

    img_tensor = img_tensor * std[:, None, None] + mean[:, None, None]
    # CHW -> HWC, clamped to the displayable [0, 1] range
    return torch.clamp(img_tensor.permute(1, 2, 0), 0.0, 1.0)
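A quick architectural sanity check needs no downloaded weights: with reflection padding around the 7x7 convs, two stride-2 downsampling steps, and two x2 upsampling steps, a randomly initialized generator should map a 256x256 input back to 256x256, with the final Tanh bounding values to [-1, 1]. A minimal sketch:

import torch
from model import generator

g = generator()                    # randomly initialized, 9 residual blocks
x = torch.randn(1, 3, 256, 256)    # dummy batch: N x C x H x W
with torch.no_grad():
    y = g(x)

assert y.shape == (1, 3, 256, 256)                         # spatial size preserved
assert -1.0 <= y.min().item() and y.max().item() <= 1.0    # Tanh output range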
requirements.txt
ADDED
@@ -0,0 +1,6 @@
streamlit>=1.28.0
torch>=2.0.0
torchvision>=0.15.0
Pillow>=10.0.0
numpy>=1.24.0
huggingface-hub>=0.19.0
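With these three files in place, the Space runs as a standard Streamlit app; the usual local invocation would be:

pip install -r requirements.txt
streamlit run app.py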