# Streamlit front end for AI background removal (RMBG 2.0).
#
# Flow: the user loads an image (URL, data URI or local file), the image is
# handed to ``process`` to obtain a mask and a background-free foreground,
# and both results can be downloaded as PNG files.
#
# NOTE(review): the copy of this file that was reviewed had lost content
# (markup inside string literals and part of ``main``). Those spots are marked
# with NOTE(review) comments below and must be restored from version control.

import base64
import io
import ssl
import urllib.request

import torch
import streamlit as st
import psutil
from PIL import Image

from process import process

st.set_page_config("Ai抠图(RMBG 2.0)", layout="wide")
# NOTE(review): this markup string is empty in the reviewed copy; with
# ``unsafe_allow_html=True`` it presumably carried a style block - confirm.
st.markdown(
    """""",
    unsafe_allow_html=True,
)

state = st.session_state

# Seed every session key exactly once so later code can read it without
# checking for existence. Empty string / None / 0 all mean "nothing yet".
_SESSION_DEFAULTS = {
    "image": "",
    "image_nbg": "",
    "mask": "",
    "filename": "",
    "image_stream": None,
    "read_file_once": 0,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in state:
        state[_key] = _default

IMAGE_FORMATS = ("jpg", "png", "jpeg", "JPG", "PNG", "JPEG")
DEVICE = "GPU" if torch.cuda.is_available() else "CPU"

# Upper bound (seconds) for fetching a remote image, so a stalled server
# cannot block the script run forever.
URL_TIMEOUT_SECONDS = 15


@st.dialog("上传图片")
def upload_image(input_image_ph, output_image_ph):
    """Dialog that loads an image from a URL, a data URI or a local file.

    On success the decoded image is stored in ``state.image``, the previous
    results (``state.mask`` / ``state.image_nbg``) are cleared, the image is
    shown in ``input_image_ph`` and ``output_image_ph`` is emptied.

    Args:
        input_image_ph: Placeholder that displays the newly loaded image.
        output_image_ph: Placeholder holding the previous result; emptied
            once a new image is loaded.
    """
    # Fallback so ``name`` is always bound, even if a stream was left over
    # in session state by an earlier run of this dialog.
    name = "image.png"

    # --- Remote image (http(s) URL or base64 data URI) ---
    st.markdown("**图片链接**", help="填写网络图片地址")
    url_col, button_col = st.columns([0.8, 0.2])
    url = url_col.text_input(
        "xxx", placeholder="url/base64...", label_visibility="collapsed"
    )
    if button_col.button("读取", use_container_width=True):
        try:
            if url.startswith(("https://", "http://")):
                # SECURITY: certificate verification is switched off here and
                # the URL is user-supplied. Kept as in the original (it lets
                # hosts with broken certificates work), but the downloaded
                # bytes must be treated as untrusted.
                context = ssl._create_unverified_context()
                with urllib.request.urlopen(
                    url, context=context, timeout=URL_TIMEOUT_SECONDS
                ) as response:
                    state.image_stream = io.BytesIO(response.read())
                name = "image." + url.rsplit(".", 1)[-1]
            elif url.startswith("data:image/"):
                header, payload = url.split(",", 1)
                state.image_stream = io.BytesIO(base64.b64decode(payload))
                # Slice the subtype out of e.g. "data:image/png;base64".
                name = "image." + header[11:-7]
            else:
                st.warning(":red[请输入有效的图片链接]")
        except Exception:
            # UI boundary: any network or decoding failure is reported to the
            # user instead of crashing the dialog.
            st.warning(":red[**读取图片失败,请保存到本地后上传**]")
            return

    # --- Local image ---
    def _mark_new_upload():
        # Fired by the uploader's on_change; lets the code below consume the
        # file once instead of on every rerun while it stays selected.
        state.read_file_once = 1

    st.markdown("**上传图片**")
    file = st.file_uploader(
        "xxx",
        accept_multiple_files=False,
        type=IMAGE_FORMATS,
        label_visibility="collapsed",
        on_change=_mark_new_upload,
        key="upload_key",
    )
    if state.read_file_once and file:
        state.image_stream = io.BytesIO(file.getvalue())
        name = file.name
        state.read_file_once = 0

    if state.image_stream is None:
        return

    try:
        image = Image.open(state.image_stream)
        state.image = image
        state.mask = ""
        state.image_nbg = ""
        state.filename = name
        input_image_ph.image(image)
        output_image_ph.empty()
        st.success(":rainbow[**上传成功**]")
    except Exception as e:
        st.warning(f":red[处理图片出错 >> {e}]")
    finally:
        # Always drop the pending stream so the same bytes are not processed
        # again on the next rerun of the dialog.
        state.image_stream = None


def _png_bytes(image):
    """Return ``image`` encoded as PNG bytes via its ``save`` method."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


@st.dialog("下载图片")
def download_image():
    """Dialog offering the mask and the foreground image as PNG downloads.

    Shows a warning instead when no result is stored in session state yet.
    """
    if not state.mask or not state.image_nbg:
        st.warning("请上传图片")
        return

    with st.spinner("正在处理中..."):
        mask_png = _png_bytes(state.mask)
        foreground_png = _png_bytes(state.image_nbg)

    # Both downloads share the uploaded file's name without its extension.
    stem = state.filename.rsplit(".", 1)[0]
    st.download_button(
        "下载掩码图片",
        data=mask_png,
        file_name=stem + "-mask.png",
        use_container_width=True,
    )
    st.download_button(
        "下载前景图片",
        data=foreground_png,
        file_name=stem + "-no-bg.png",
        use_container_width=True,
    )


def main():
    """Page entry point: wires the buttons to upload, processing and download."""
    # NOTE(review): this function is incomplete in the reviewed copy. The
    # markup around ``{cpustates}`` is gone, and so is the code that defines
    # ``cpustates``, ``submit_btn``, ``process_btn``, ``download_btn``,
    # ``input_image_ph``, ``output_image_ph`` and ``output_container``.
    # Only the surviving statements are kept below; restore the missing layout
    # code from version control before running. As written, the literal is
    # rendered verbatim and the names above are undefined.
    st.markdown(
        "\n{cpustates}\n",
        unsafe_allow_html=True,
    )

    if submit_btn:
        upload_image(input_image_ph, output_image_ph)

    if process_btn:
        if state.image:
            with output_image_ph.container(), st.spinner(f"正在处理中({DEVICE})..."):
                mask, image_nbg = process(state.image)
                state.image_nbg = image_nbg
                state.mask = mask
                output_container.image(image_nbg)
            st.rerun()
        else:
            st.toast("请上传图片")

    if download_btn:
        download_image()


if __name__ == "__main__":
    main()