File size: 6,619 Bytes
da70cd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import io
import ssl
import base64
import torch
import streamlit as st
import urllib.request
import psutil
from PIL import Image
from process import process

st.set_page_config("Ai抠图(RMBG 2.0)", layout="wide")

st.markdown(
    """<style>

.stDeployButton {

    visibility: hidden;

}

.block-container {

    padding: 3rem 2rem 2rem 2rem;

}

.st-emotion-cache-1mi2ry5 {

    padding: 0rem 1rem;

}

.st-emotion-cache-1gwvy71 {

    padding: 1rem 1rem;

}

</style>""",
    unsafe_allow_html=True,
)

state = st.session_state
if "image" not in state:
    state.image = ""
if "image_nbg" not in state:
    state.image_nbg = ""
if "mask" not in state:
    state.mask = ""
if "filename" not in state:
    state.filename = ""
if "image_stream" not in state:
    state.image_stream = None
if "read_file_once" not in state:
    state.read_file_once = 0


IMAGE_FORMATS = ("jpg", "png", "jpeg", "JPG", "PNG", "JPEG")
DEVICE = "GPU" if torch.cuda.is_available() else "CPU"


@st.dialog("上传图片")
def upload_image(input_image_ph, output_image_ph):

    # 网络图片
    st.markdown("**图片链接**", help="填写网络图片地址")
    cls = st.columns([0.8, 0.2])
    url = cls[0].text_input(
        "xxx", placeholder="url/base64...", label_visibility="collapsed"
    )
    if cls[1].button(
        "读取",
        use_container_width=True,
        # disabled=not url or not url.startswith(("https://", "data:image/")),
    ):
        try:
            if url.startswith(("https://", "http://")):
                content = ssl._create_unverified_context()
                with urllib.request.urlopen(url, context=content) as response:
                    image_data = response.read()
                    state.image_stream = io.BytesIO(image_data)
                    name = "image." + url.rsplit(".", 1)[-1]
            elif url.startswith("data:image/"):
                pfix, base64_data = url.split(",", 1)
                state.image_stream = io.BytesIO(base64.b64decode(base64_data))
                name = "image." + pfix[11:-7]
            else:
                st.warning(":red[请输入有效的图片链接]")
        except Exception as e:
            st.warning(f":red[**读取图片失败,请保存到本地后上传**]")
            return

    # 本地图片
    def _cb():
        state.read_file_once = 1

    st.markdown("**上传图片**")
    file = st.file_uploader(
        "xxx",
        accept_multiple_files=False,
        type=IMAGE_FORMATS,
        label_visibility="collapsed",
        on_change=_cb,
        key="upload_key",
    )
    if state.read_file_once and file:
        state.image_stream = io.BytesIO(file.getvalue())
        name = file.name
        state.read_file_once = 0

    if state.image_stream is not None:
        try:
            image = Image.open(state.image_stream)
            state.image = image
            state.mask = ""
            state.image_nbg = ""
            state.filename = name

            input_image_ph.image(image)
            output_image_ph.empty()
            st.success(":rainbow[**上传成功**]")
        except Exception as e:
            st.warning(f":red[处理图片出错 >> {e}]")

        state.image_stream = None


@st.dialog("下载图片")
def download_image():
    if not state.mask or not state.image_nbg:
        st.warning("请上传图片")
    else:
        with st.spinner("正在处理中..."):
            buffer1 = io.BytesIO()
            state.mask.save(buffer1, format="PNG")
            buffer2 = io.BytesIO()
            state.image_nbg.save(buffer2, format="PNG")

        name = state.filename.rsplit(".", 1)[0] + "-mask.png"
        st.download_button(
            "下载掩码图片",
            data=buffer1.getvalue(),
            file_name=name,
            use_container_width=True,
            disabled=not state.mask,
        )

        name = state.filename.rsplit(".", 1)[0] + "-no-bg.png"
        st.download_button(
            "下载前景图片",
            data=buffer2.getvalue(),
            file_name=name,
            use_container_width=True,
            disabled=not state.image_nbg,
        )


def main():
    st.markdown(
        '<h1 style="text-align: center; color: white; background: #4b4bff; font-size: 26px; border-radius: .5rem; margin-bottom: 15px;">Ai抠图 (RMBG 2.0)</h1>',
        unsafe_allow_html=True,
    )
    body_cls = st.columns(2)

    body_cls[0].markdown(
        "<h6 style='text-align: center;'>原图像</h6>",
        unsafe_allow_html=True,
    )
    body_cls[1].markdown(
        "<h6 style='text-align: center;'>处理后</h6>",
        unsafe_allow_html=True,
    )

    HEIGHT = 400
    input_container = body_cls[0].container(height=HEIGHT)
    output_container = body_cls[1].container(height=HEIGHT)
    input_image_ph = input_container.empty()
    output_image_ph = output_container.empty()

    # show image
    if state.image:
        input_image_ph.image(state.image)
    if state.image_nbg:
        output_image_ph.image(state.image_nbg)

    btm_cls = st.columns(3)
    submit_btn = btm_cls[0].button(
        ":orange[:material/Cloud_Upload: **上传图片**]", use_container_width=True
    )
    process_ph = btm_cls[1].empty()
    process_btn = process_ph.button(
        ":rainbow[:material/Hourglass_Empty: **进行处理**]", use_container_width=True
    )
    download_btn = btm_cls[2].button(
        ":green[:material/Download_2: **下载图片**]",
        use_container_width=True,
        disabled=not state.mask,
    )

    if DEVICE == "CPU":
        cpu_percent = psutil.cpu_percent(interval=0.5)
        cpustates = (
            f"🌟CPU运行会比较慢,请耐心等待一下~🫡(CPU利用率:{cpu_percent:.3f}%)"
        )
        st.caption(
            f'<p style="text-align: center;">{cpustates}</p>',
            unsafe_allow_html=True,
        )

    if submit_btn:
        upload_image(input_image_ph, output_image_ph)

    if process_btn:
        if state.image:
            with output_image_ph.container(), st.spinner(f"正在处理中({DEVICE})..."):
                mask, image_nbg = process(state.image)

            state.image_nbg = image_nbg
            state.mask = mask
            output_container.image(image_nbg)
            st.rerun()
        else:
            st.toast("请上传图片")

    if download_btn:
        download_image()


if __name__ == "__main__":
    main()