Spaces:
Sleeping
Sleeping
File size: 3,483 Bytes
7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c dd6e157 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 7cf938c 1121140 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import streamlit as st
import torch
import numpy as np
from PIL import Image
from model import CycleGAN, get_val_transform, de_normalize
# Page-level configuration; kept ahead of every other st.* call in the script.
_PAGE_CONFIG = dict(
    page_title="CycleGAN Image Converter",
    page_icon="π¨",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)
@st.cache_resource
def get_device():
    """Pick the torch device used for inference and announce it in the sidebar.

    Preference order is CUDA, then Apple Silicon (MPS), then CPU. The result
    is cached with ``st.cache_resource``, so detection is reused across reruns.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        st.sidebar.success("Using GPU π")
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        st.sidebar.success("Using Apple Silicon π")
        return torch.device("mps")
    # Neither accelerator is available: fall back to the CPU.
    st.sidebar.info("Using CPU π»")
    return torch.device("cpu")
# Custom page CSS: cap the app width at 1200px, centre it, and pad the main area.
_PAGE_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
padding: 2rem;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Page heading followed by a short usage blurb (Markdown).
st.title("CycleGAN Image Converter π¨")
_INTRO_MARKDOWN = """
Transform images between different domains using CycleGAN.
Upload an image and see it converted in real-time!
*Note: Images will be resized to 256x256 pixels during conversion.*
"""
st.markdown(_INTRO_MARKDOWN)
# Available conversions. Each entry has:
#   name        - label shown in the "Conversion Type" selectbox
#   id          - short identifier for the domain pair
#   model_path  - identifier handed to CycleGAN.from_pretrained (Hub-style "owner/repo")
#   description - one-line human-readable summary of the conversion
MODELS = [
    dict(
        name="Cezanne β Photo",
        id="cezanne2photo",
        model_path="waleko/cyclegan",
        description="Convert between Cezanne's painting style and photographs",
    ),
    dict(
        name="Day β Night",
        id="day2night",
        model_path="waleko/cyclegan-day_night",
        description="Convert between day and night cityscapes",
    ),
]
# Sidebar controls: which model pair to use and which way to translate.
# Produces two module-level values read later by the conversion UI:
#   selected_model - index into MODELS
#   direction      - the chosen radio label; process_image compares it
#                    against the first option to pick generator_ab
with st.sidebar:
    st.header("Settings")
    selected_model = st.selectbox(
        "Conversion Type",
        options=range(len(MODELS)),
        format_func=lambda x: MODELS[x]["name"]
    )
    # Show what the selected model converts between, using the description
    # each MODELS entry carries.
    st.caption(MODELS[selected_model]["description"])
    direction = st.radio(
        "Conversion Direction",
        options=["A β B", "B β A"],
        help="A β B: Convert from domain A to B\nB β A: Convert from domain B to A"
    )
@st.cache_resource
def load_model(model_path):
    """Load a pretrained CycleGAN ready for inference.

    The checkpoint identified by *model_path* is loaded with
    ``CycleGAN.from_pretrained``, moved to the device from ``get_device()`` and
    switched to evaluation mode. The result is cached per *model_path* with
    ``st.cache_resource``, so each checkpoint is loaded only once.

    Args:
        model_path: identifier passed to ``CycleGAN.from_pretrained``
            (the MODELS entries use Hub-style "owner/repo" ids).

    Returns:
        The CycleGAN model, on the inference device, in eval mode.
    """
    target_device = get_device()
    network = CycleGAN.from_pretrained(model_path).to(target_device)
    network.eval()
    return network
def process_image(image, model, direction):
    """Translate one PIL image with the CycleGAN generator for *direction*.

    Args:
        image: input PIL image; converted with ``np.array`` before the
            model's validation transform is applied. The UI converts RGBA
            to RGB before calling; presumably RGB is required — TODO confirm.
        model: a loaded CycleGAN exposing ``generator_ab`` / ``generator_ba``.
        direction: the sidebar radio label. "A β B" selects ``generator_ab``;
            any other value selects ``generator_ba``.

    Returns:
        A PIL image in "RGB" mode built from ``de_normalize``'s output.
        Assumes that output is an HxWx3 uint8 tensor — TODO confirm in model.py.
    """
    to_input = get_val_transform(model, direction)
    batch = to_input(np.array(image)).unsqueeze(0)  # add a leading batch dimension
    # Run on whatever device the model's weights live on.
    batch = batch.to(next(model.parameters()).device)
    generator = model.generator_ab if direction == "A β B" else model.generator_ba
    with torch.no_grad():
        translated = generator(batch)
        restored = de_normalize(translated[0], model, direction)
    pixels = restored.cpu().detach().numpy()
    return Image.fromarray(pixels, 'RGB')
# Two-column layout: the upload and its preview on the left, the converted
# result on the right.
col1, col2 = st.columns(2)

# Fully decoded RGB copy of the upload; stays None until a readable image arrives.
input_image = None

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            # Normalise every PIL mode (RGBA, LA, L, P, CMYK, ...) to 3-channel RGB.
            # Without this, grayscale uploads become 2-D arrays and palette uploads
            # become index arrays in np.array(), which the transform cannot use.
            # convert() also forces a full decode, so corrupt or truncated files
            # fail here instead of mid-conversion.
            input_image = Image.open(uploaded_file).convert('RGB')
        except (OSError, ValueError, Image.DecompressionBombError) as err:
            # OSError covers PIL.UnidentifiedImageError and truncated data;
            # ValueError covers unsupported mode conversions.
            input_image = None
            st.error(f"Could not read the uploaded file as an image: {err}")
        else:
            st.image(input_image)

with col2:
    st.subheader("Converted Image")
    if input_image is not None:
        try:
            model = load_model(MODELS[selected_model]["model_path"])
            result = process_image(input_image, model, direction)
            st.image(result)
        except Exception as e:
            # Give a readable in-page message, then re-raise so Streamlit still
            # reports the full traceback. e.__traceback__ is not interpolated:
            # it renders only as "<traceback object at 0x...>".
            st.error(f"Error during conversion: {e}")
            raise
    else:
        st.info("Upload an image to see the conversion result")
|