"""Streamlit front end for converting images with pretrained CycleGAN models."""

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageOps

from model import CycleGAN, de_normalize, get_val_transform

st.set_page_config(
    page_title="CycleGAN Image Converter",
    page_icon="🎨",
    layout="wide"
)


@st.cache_resource
def get_device():
    """Pick the best available torch device and announce it in the sidebar.

    Preference order is CUDA, then Apple Silicon (MPS), then CPU. The result
    is cached with ``st.cache_resource`` so the probe is not repeated on every
    rerun.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.sidebar.success("Using GPU 🚀")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        st.sidebar.success("Using Apple Silicon 🍎")
    else:
        device = torch.device("cpu")
        st.sidebar.info("Using CPU 💻")
    return device


# NOTE(review): this renders nothing as written; it looks like a placeholder
# for custom CSS/HTML that was removed — TODO confirm or delete.
st.markdown("""
""", unsafe_allow_html=True)

st.title("CycleGAN Image Converter 🎨")
st.markdown("""
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!

*Note: Images will be resized to 256x256 pixels during conversion.*
""")

# Available conversions. Only "name" (sidebar label) and "model_path"
# (passed to CycleGAN.from_pretrained) are read in this file.
MODELS = [
    {
        "name": "Cezanne ↔ Photo",
        "id": "cezanne2photo",
        "model_path": "waleko/cyclegan",
        "description": "Convert between Cezanne's painting style and photographs"
    },
    {
        "name": "Day ↔ Night",
        "id": "day2night",
        "model_path": "waleko/cyclegan-day_night",
        "description": "Convert between day and night cityscapes"
    }
]

with st.sidebar:
    st.header("Settings")

    # The selectbox returns an index into MODELS; the label comes from "name".
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"]
    )

    direction = st.radio(
        "Conversion Direction",
        options=["A → B", "B → A"],
        help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A"
    )


@st.cache_resource
def load_model(model_path):
    """Load a pretrained CycleGAN, move it to the best device and set eval mode.

    Cached with ``st.cache_resource`` so each ``model_path`` is loaded once
    rather than on every Streamlit rerun.

    Args:
        model_path: Identifier passed to ``CycleGAN.from_pretrained``
            (presumably a Hugging Face Hub repo id — confirm in ``model``).

    Returns:
        The loaded ``CycleGAN`` instance in eval mode.
    """
    device = get_device()
    model = CycleGAN.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model


def process_image(image, model, direction):
    """Run one CycleGAN generator over ``image`` and return the converted image.

    Args:
        image: Input ``PIL.Image``; the caller converts it to RGB first.
        model: A loaded ``CycleGAN`` (see ``load_model``).
        direction: ``"A → B"`` uses ``model.generator_ab``; any other value
            uses ``model.generator_ba``.

    Returns:
        The converted image as an RGB ``PIL.Image``.
    """
    transform = get_val_transform(model, direction)
    # Add a batch dimension of 1 and place the input on the model's device.
    tensor = transform(np.array(image)).unsqueeze(0)
    tensor = tensor.to(next(model.parameters()).device)

    with torch.no_grad():
        if direction == "A → B":
            output = model.generator_ab(tensor)
        else:
            output = model.generator_ba(tensor)
        result = de_normalize(output[0], model, direction)

    # assumes de_normalize returns an HxWx3 uint8 tensor, as
    # Image.fromarray(..., 'RGB') requires — TODO confirm in model.py
    image = Image.fromarray(result.cpu().detach().numpy(), 'RGB')
    return image


col1, col2 = st.columns(2)

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        input_image = Image.open(uploaded_file)
        # Honour the EXIF orientation tag so camera/phone photos are not
        # displayed and converted sideways.
        input_image = ImageOps.exif_transpose(input_image)
        # Normalise every non-RGB mode, not just RGBA: np.array() of an "L"
        # (grayscale) or "P" (palette PNG) image is 2-D, which the
        # 3-channel RGB pipeline cannot consume.
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')
        st.image(input_image)

with col2:
    st.subheader("Converted Image")

    if uploaded_file is not None:
        try:
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result)
        except Exception as e:
            # Show a readable message, then re-raise so the failure still
            # propagates to Streamlit's own error reporting and the logs.
            st.error(f"Error during conversion: {e}")
            raise
    else:
        st.info("Upload an image to see the conversion result")