File size: 3,483 Bytes
7cf938c
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
 
 
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
 
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
dd6e157
 
 
 
 
 
7cf938c
 
 
1121140
7cf938c
 
 
1121140
7cf938c
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
 
 
1121140
7cf938c
 
 
 
 
 
 
 
 
 
1121140
 
 
7cf938c
 
 
 
 
 
 
 
 
1121140
 
 
 
7cf938c
 
 
 
 
1121140
7cf938c
 
 
1121140
 
7cf938c
1121140
 
7cf938c
1121140
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import streamlit as st
import torch
import numpy as np
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize


# Page-level settings. Keep this ahead of other Streamlit calls: some Streamlit
# versions require set_page_config to be the first st command of the script.
st.set_page_config(layout="wide", page_title="CycleGAN Image Converter", page_icon="🎨")


@st.cache_resource
def get_device():
    """Pick the best available torch device and announce it in the sidebar.

    Preference order is CUDA, then Apple-Silicon MPS, then CPU. The function
    is wrapped in ``st.cache_resource``, so the probe result is cached and
    reused across reruns instead of being recomputed each time.

    Returns:
        torch.device: the selected device.
    """
    # ``torch.backends.mps`` may not exist on older torch builds, so look it up
    # defensively rather than assuming the attribute is present.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.sidebar.success("Using GPU 🚀")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
        st.sidebar.success("Using Apple Silicon 🍎")
    else:
        device = torch.device("cpu")
        st.sidebar.info("Using CPU 💻")
    return device


# Centre the app in a 1200px-wide column and pad the main area. The CSS is
# injected as raw HTML, which is why unsafe_allow_html is required.
_CUSTOM_CSS = """
    <style>
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    .main {
        padding: 2rem;
    }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


# Short user-facing introduction rendered under the page title.
_INTRO_MD = """
    Transform images between different domains using CycleGAN.
    Upload an image and see it converted in real-time!
    
    *Note: Images will be resized to 256x256 pixels during conversion.*
"""

st.title("CycleGAN Image Converter 🎨")
st.markdown(_INTRO_MD)


# Available pretrained models. Each entry is a dict with the keys below:
#   name        - label shown in the "Conversion Type" selector
#   id          - short identifier for the pair of domains
#   model_path  - identifier handed to CycleGAN.from_pretrained
#   description - one-line human-readable summary
_MODEL_FIELDS = ("name", "id", "model_path", "description")
MODELS = [
    dict(zip(_MODEL_FIELDS, row))
    for row in (
        (
            "Cezanne ↔ Photo",
            "cezanne2photo",
            "waleko/cyclegan",
            "Convert between Cezanne's painting style and photographs",
        ),
        (
            "Day ↔ Night",
            "day2night",
            "waleko/cyclegan-day_night",
            "Convert between day and night cityscapes",
        ),
    )
]


# Direction values -> label displayed in the UI. The values themselves are kept
# byte-for-byte as before, because process_image() compares against them and
# forwards them to get_val_transform()/de_normalize(). They are mis-encoded
# copies of "→" (UTF-8 bytes decoded as cp1253), so only the visible label is
# repaired here; the value the radio returns is unchanged.
_DIRECTION_LABELS = {
    "A β†’ B": "A → B",
    "B β†’ A": "B → A",
}

with st.sidebar:
    st.header("Settings")

    # Returns an index into MODELS; the main layout looks the entry up by index.
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"],
    )

    # Returns one of the _DIRECTION_LABELS keys (first one by default).
    direction = st.radio(
        "Conversion Direction",
        options=list(_DIRECTION_LABELS),
        format_func=_DIRECTION_LABELS.__getitem__,
        help="A → B: Convert from domain A to B\nB → A: Convert from domain B to A",
    )


@st.cache_resource
def load_model(model_path):
    """Fetch a pretrained CycleGAN and prepare it for inference.

    Cached with ``st.cache_resource``, so each ``model_path`` is loaded once
    and the same model object is reused across reruns.

    Args:
        model_path: identifier passed to ``CycleGAN.from_pretrained`` (the
            ``MODELS`` entries use ``owner/repo``-style ids).

    Returns:
        The model, moved to the device chosen by ``get_device`` and switched
        to eval mode.
    """
    target_device = get_device()
    network = CycleGAN.from_pretrained(model_path).to(target_device)
    # Inference mode for layers that behave differently in training
    # (e.g. dropout / batch-norm, if the architecture has them).
    network.eval()
    return network


def process_image(image, model, direction):
    """Translate ``image`` into the other domain with one of the model's generators.

    Args:
        image: input picture, normally a ``PIL.Image.Image``. PIL images in any
            mode other than RGB (grayscale, palette, RGBA, CMYK, ...) are
            converted to RGB first.
        model: loaded CycleGAN exposing ``generator_ab`` / ``generator_ba``.
        direction: sidebar direction value. The A-to-B value selects
            ``generator_ab`` and anything else selects ``generator_ba``. It is
            also forwarded to ``get_val_transform`` and ``de_normalize``.

    Returns:
        PIL.Image.Image: the generated image, built in 'RGB' mode.
    """
    # np.array() of a non-RGB PIL image has the wrong channel layout:
    # "L"/"P" give a 2-D array, "LA" 2 channels, "RGBA"/"CMYK" 4 channels.
    # Normalise to RGB so the transform always sees the same layout.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = get_val_transform(model, direction)
    batch = transform(np.array(image)).unsqueeze(0)  # add a batch dimension
    # Run on whichever device the model's weights live on.
    batch = batch.to(next(model.parameters()).device)

    generator = model.generator_ab if direction == "A β†’ B" else model.generator_ba
    with torch.no_grad():
        output = generator(batch)

    result = de_normalize(output[0], model, direction)
    return Image.fromarray(result.cpu().detach().numpy(), 'RGB')


col1, col2 = st.columns(2)

# Decoded upload, normalised to RGB. It stays None when nothing usable was uploaded.
input_image = None

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        try:
            # convert("RGB") handles every PIL mode (RGBA, LA, L, P, CMYK, ...),
            # not just RGBA. Because it produces a new image, it also forces
            # PIL's lazy decoder to run now, so a corrupt upload fails here.
            input_image = Image.open(uploaded_file).convert("RGB")
        except OSError as e:
            # PIL's UnidentifiedImageError and truncated-file errors are OSErrors.
            st.error(f"Could not read the uploaded image: {e}")
        else:
            st.image(input_image)

with col2:
    st.subheader("Converted Image")
    if input_image is not None:
        try:
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result)
        except Exception as e:
            # Show a readable message, then re-raise so Streamlit still renders
            # the full traceback. Interpolating e.__traceback__ only printed
            # "<traceback object at 0x...>".
            st.error(f"Error during conversion: {e}")
            raise
    else:
        st.info("Upload an image to see the conversion result")