Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import numpy as np | |
from PIL import Image | |
from model import CycleGAN, get_val_transform, de_normalize | |
# Browser-tab title/icon and wide layout. Kept as the first Streamlit call,
# which older Streamlit versions require for set_page_config.
st.set_page_config(
    page_title="CycleGAN Image Converter",
    page_icon="🎨",  # was mis-decoded as "π¨"
    layout="wide",
)
def get_device():
    """Pick the best available torch device and report the choice in the sidebar.

    Preference order is CUDA, then Apple MPS, then CPU. A short status
    message naming the chosen backend is written to ``st.sidebar``.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        st.sidebar.success("Using GPU 🚀")
        return torch.device("cuda")

    # torch.backends.mps does not exist in older torch releases; look it up
    # defensively so those installs fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        st.sidebar.success("Using Apple Silicon 🍎")
        return torch.device("mps")

    st.sidebar.info("Using CPU 💻")
    return torch.device("cpu")
# Page-level CSS: constrain the app to a centered 1200px column and add
# padding around the main area. Injected as raw HTML, hence unsafe_allow_html.
_PAGE_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
padding: 2rem;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Page heading plus a short usage blurb shown above the input/output columns.
st.title("CycleGAN Image Converter 🎨")  # emoji was mis-decoded as "π¨"
st.markdown("""
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!
*Note: Images will be resized to 256x256 pixels during conversion.*
""")
# Available conversions. The sidebar selectbox shows "name" and yields an
# index into this list; "model_path" is what gets passed to
# CycleGAN.from_pretrained. "id" and "description" are descriptive metadata.
MODELS = [
    dict(
        name="Cezanne β Photo",
        id="cezanne2photo",
        model_path="waleko/cyclegan",
        description="Convert between Cezanne's painting style and photographs",
    ),
    dict(
        name="Day β Night",
        id="day2night",
        model_path="waleko/cyclegan-day_night",
        description="Convert between day and night cityscapes",
    ),
]
# Sidebar controls. Produces two module-level values used further down:
#   selected_model - index into MODELS
#   direction      - raw radio option value (see note below)
with st.sidebar:
    st.header("Settings")
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"]
    )
    # NOTE(review): "β" in the option values looks like a mis-decoded "→".
    # The raw values are kept unchanged because process_image compares them
    # verbatim and forwards them to get_val_transform / de_normalize; only
    # the on-screen labels are repaired. format_func does not affect the
    # value st.radio returns.
    direction = st.radio(
        "Conversion Direction",
        options=["A β B", "B β A"],
        format_func=lambda option: option.replace("β", "→"),
        # A blank line is needed: help text is rendered as Markdown, where a
        # single newline is only a soft break and the two lines would merge.
        help="A → B: Convert from domain A to B\n\nB → A: Convert from domain B to A"
    )
@st.cache_resource(show_spinner="Loading model...")
def load_model(model_path):
    """Load a pretrained CycleGAN onto the best available device in eval mode.

    Streamlit re-executes this script on every widget interaction, so without
    caching the weights would be fetched and rebuilt on every click.
    ``st.cache_resource`` keeps one model instance per ``model_path`` for the
    lifetime of the server process, shared across sessions.

    NOTE(review): get_device() writes a status message to the sidebar; confirm
    the installed Streamlit version replays st.* elements on cache hits,
    otherwise that message only appears on the first (uncached) load.

    Args:
        model_path: identifier passed to ``CycleGAN.from_pretrained``
            (e.g. "waleko/cyclegan").

    Returns:
        The CycleGAN model, moved to the selected device and set to eval mode.
    """
    device = get_device()
    model = CycleGAN.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model
def process_image(image, model, direction):
    """Translate a PIL image with one of the model's two generators.

    Args:
        image: PIL image to convert. Images in any mode other than "RGB"
            (grayscale "L", palette "P", "RGBA", "CMYK", ...) are converted
            to RGB first.
        model: CycleGAN instance exposing ``generator_ab`` / ``generator_ba``.
        direction: the sidebar radio value; exactly "A β B" selects
            ``generator_ab``, anything else selects ``generator_ba``. It is
            also forwarded to ``get_val_transform`` and ``de_normalize``.

    Returns:
        PIL.Image.Image: the generated image, built in "RGB" mode.
    """
    # np.array() of an "L" or "P" image is 2-D (H, W), not (H, W, 3), so the
    # transform would receive the wrong shape; normalize the mode up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = get_val_transform(model, direction)
    batch = transform(np.array(image)).unsqueeze(0)  # add a batch dimension
    # Run on whatever device the model's weights live on.
    batch = batch.to(next(model.parameters()).device)

    generator = model.generator_ab if direction == "A β B" else model.generator_ba
    with torch.no_grad():
        output = generator(batch)

    result = de_normalize(output[0], model, direction)
    return Image.fromarray(result.cpu().detach().numpy(), 'RGB')
# Two-column layout: the upload and preview on the left, the converted result
# on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        input_image = Image.open(uploaded_file)
        # Normalize every non-RGB mode (RGBA, LA, palette "P", grayscale "L",
        # CMYK, ...), not just RGBA: the converter works on 3-channel images.
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')
        st.image(input_image)

with col2:
    st.subheader("Converted Image")
    if uploaded_file is not None:
        try:
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result)
        except Exception as e:
            # Show a readable message in the app, then re-raise so Streamlit
            # still renders the full traceback. (Interpolating e.__traceback__
            # only printed "<traceback object at 0x...>".)
            st.error(f"Error during conversion: {e}")
            raise
    else:
        st.info("Upload an image to see the conversion result")