waleko's picture
add daynight
dd6e157
import streamlit as st
import torch
import numpy as np
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize
# Page-level configuration. Kept ahead of every other element-producing st.*
# call (Streamlit expects set_page_config to run first in many versions).
st.set_page_config(
    layout="wide",
    page_icon="🎨",
    page_title="CycleGAN Image Converter",
)
@st.cache_resource
def get_device():
    """Pick the torch device to run inference on and report it in the sidebar.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU. The result is
    cached with ``st.cache_resource``.

    Returns:
        torch.device: the selected device.
    """
    # torch.backends.mps does not exist on older torch builds (added in 1.12),
    # so look it up defensively instead of letting an AttributeError abort the app.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.sidebar.success("Using GPU 🚀")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
        st.sidebar.success("Using Apple Silicon 🍎")
    else:
        device = torch.device("cpu")
        st.sidebar.info("Using CPU 💻")
    return device
# Custom CSS: cap the app width at 1200px, centre it, and pad the main area.
# Injected as raw HTML, hence unsafe_allow_html.
_PAGE_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
padding: 2rem;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Page heading followed by a short usage blurb rendered as Markdown.
st.title("CycleGAN Image Converter 🎨")
_INTRO_MARKDOWN = """
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!
*Note: Images will be resized to 256x256 pixels during conversion.*
"""
st.markdown(_INTRO_MARKDOWN)
# Available conversions, in sidebar order.
#   name        -- label shown in the "Conversion Type" selectbox
#   id          -- short identifier for the conversion
#   model_path  -- passed to load_model / CycleGAN.from_pretrained
#                  (presumably a Hugging Face Hub repo id)
#   description -- human-readable summary of the conversion
MODELS = [
    dict(
        name="Cezanne ↔ Photo",
        id="cezanne2photo",
        model_path="waleko/cyclegan",
        description="Convert between Cezanne's painting style and photographs",
    ),
    dict(
        name="Day ↔ Night",
        id="day2night",
        model_path="waleko/cyclegan-day_night",
        description="Convert between day and night cityscapes",
    ),
]
# Sidebar controls: which model to use and which way to convert.
with st.sidebar:
    st.header("Settings")
    # The selectbox value is an index into MODELS; only the name is displayed.
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"],
    )
    # Surface the selected model's description from MODELS.
    st.caption(MODELS[selected_model]["description"])
    # The option *values* are compared verbatim in process_image and forwarded
    # to get_val_transform / de_normalize, so they must stay byte-identical.
    # format_func only repairs the mis-encoded arrow in the *displayed* label.
    direction = st.radio(
        "Conversion Direction",
        options=["A β†’ B", "B β†’ A"],
        format_func=lambda label: label.replace("β†’", "→"),
        # Help text is Markdown: a blank line is needed for a visible line break.
        help="A → B: Convert from domain A to B\n\nB → A: Convert from domain B to A",
    )
@st.cache_resource
def load_model(model_path):
    """Load a pretrained CycleGAN and prepare it for inference.

    The model is moved to the device chosen by ``get_device`` and switched to
    eval mode. ``st.cache_resource`` keys the cache on ``model_path``, so each
    checkpoint is loaded once and the same object is reused on later reruns.
    """
    target = get_device()
    net = CycleGAN.from_pretrained(model_path).to(target)
    net.eval()
    return net
def process_image(image, model, direction):
    """Translate ``image`` into the other domain using ``model``.

    Args:
        image: PIL image to convert. Any non-RGB mode (grayscale, palette,
            RGBA, CMYK, ...) is converted to RGB first.
        model: CycleGAN exposing ``generator_ab`` and ``generator_ba``.
        direction: ``"A β†’ B"`` selects ``generator_ab``; any other value
            selects ``generator_ba``. Also forwarded to ``get_val_transform``
            and ``de_normalize``.

    Returns:
        The converted image as a PIL ``RGB`` image.
    """
    # np.array() on an L/P/LA/CMYK image yields a channel count other than 3,
    # which the RGB model pipeline presumably cannot consume (TODO confirm in
    # get_val_transform). Normalise here so every caller is covered.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = get_val_transform(model, direction)
    batch = transform(np.array(image)).unsqueeze(0)  # add a batch dimension
    # Run on whichever device the model's weights live on.
    batch = batch.to(next(model.parameters()).device)
    with torch.no_grad():
        if direction == "A β†’ B":
            output = model.generator_ab(batch)
        else:
            output = model.generator_ba(batch)
    result = de_normalize(output[0], model, direction)
    return Image.fromarray(result.cpu().detach().numpy(), 'RGB')
# Two-column layout: upload on the left, conversion result on the right.
col1, col2 = st.columns(2)

# Decoded RGB upload, or None when nothing readable has been uploaded.
input_image = None

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            # convert() forces a full decode, so corrupt or truncated files fail
            # here rather than later. It also normalises every mode (RGBA, LA,
            # L, P, ...) to RGB; the original code only handled RGBA.
            input_image = Image.open(uploaded_file).convert("RGB")
        except (OSError, Image.DecompressionBombError):
            # OSError also covers PIL.UnidentifiedImageError and truncated files.
            st.error("Could not read the uploaded file as an image.")
        else:
            st.image(input_image)

with col2:
    st.subheader("Converted Image")
    if input_image is not None:
        try:
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result)
        except Exception as e:
            # Show a readable message in the app, then re-raise so Streamlit
            # still renders the full traceback (and it reaches the server log).
            st.error(f"Error during conversion: {e}")
            raise
    elif uploaded_file is None:
        st.info("Upload an image to see the conversion result")