Spaces:
Running
Running
import streamlit as st | |
import av # streamlit-webrtcκ° λΉλμ€ νλ μμ λ€λ£¨κΈ° μν΄ μ¬μ© | |
import cv2 | |
import numpy as np | |
from ultralytics import YOLO | |
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode | |
# --- μ€μ --- | |
MODEL_PATH = 'trained_model.pt' # νμ΅ν YOLO λͺ¨λΈ νμΌ κ²½λ‘ | |
CONFIDENCE_THRESHOLD = 0.4 # κ°μ²΄ νμ§ μ΅μ μ λ’°λ (λͺ¨λΈμ λ°λΌ μ‘°μ ) | |
SEND_ALERT_INTERVAL = 30 # λ΄λ°° νμ§ μ λ°μ΄ν° μ±λ λ©μμ§λ₯Ό λͺ νλ μλ§λ€ 보λΌμ§ (λ무 μ§§μΌλ©΄ ν΄λΌμ΄μΈνΈ λΆλ΄) | |
# --- YOLO λͺ¨λΈ λ‘λ --- | |
# Streamlitμ μΊμ± κΈ°λ₯ μ¬μ©: μ± μ€ν μ€ λͺ¨λΈμ ν λ²λ§ λ‘λ | |
def load_yolo_model(model_path): | |
try: | |
model = YOLO(model_path) | |
if hasattr(model, 'model') and model.model is not None: | |
st.success(f"YOLO λͺ¨λΈ λ‘λ μ±κ³΅: {model_path}") | |
return model | |
else: | |
st.error(f"YOLO λͺ¨λΈ λ‘λ μ€ν¨ λλ κ°μ²΄ μ΄κΈ°ν λ¬Έμ : {model_path}") | |
st.stop() | |
except FileNotFoundError: | |
st.error(f"μ€λ₯: λͺ¨λΈ νμΌμ΄ μμ΅λλ€. '{model_path}' κ²½λ‘λ₯Ό νμΈνμΈμ.") | |
st.stop() # λͺ¨λΈ νμΌ μμΌλ©΄ μ± μ€μ§ | |
except Exception as e: | |
st.error(f"YOLO λͺ¨λΈ λ‘λ μ€ μ€λ₯ λ°μ: {e}") | |
st.stop() # λͺ¨λΈ λ‘λ μ€ν¨ μ μ± μ€μ§ | |
model = load_yolo_model(MODEL_PATH) | |
# --- Streamlit-WebRTCλ₯Ό μν λΉλμ€ λ³ν ν΄λμ€ --- | |
# μ΄ ν΄λμ€μ μΈμ€ν΄μ€κ° λΉλμ€ νλ μλ§λ€ νΈμΆλ©λλ€. | |
class YOLOVideoTransformer(VideoTransformerBase): | |
# __init__μ data_channel μΈμλ₯Ό μΆκ°νμ¬ ν΄λΌμ΄μΈνΈμ ν΅μ | |
def __init__(self, model, confidence_thresh, send_interval, data_channel): | |
self.model = model | |
self.confidence_thresh = confidence_thresh | |
self.send_interval = send_interval | |
self._data_channel = data_channel # ν΄λΌμ΄μΈνΈμ ν΅μ ν λ°μ΄ν° μ±λ κ°μ²΄ | |
self.detected_in_prev_frame = False # μ΄μ νλ μμμ νμ§λμλμ§ μ¬λΆ | |
self.frame_counter = 0 # νλ μ μΉ΄μ΄ν° | |
# κ° λΉλμ€ νλ μμ μ²λ¦¬νλ λ©μλ (λΉλκΈ° ν¨μλ‘ μ μ) | |
async def recv(self, frame: av.VideoFrame) -> av.VideoFrame: | |
self.frame_counter += 1 | |
# AV νλ μμ OpenCV(numpy) μ΄λ―Έμ§λ‘ λ³ν | |
img = frame.to_ndarray(format="bgr24") | |
# YOLOv8 λͺ¨λΈλ‘ κ°μ²΄ νμ§ | |
# verbose=False: μ½μμ νμ§ κ²°κ³Ό μΆλ ₯ μ ν¨ | |
results = self.model(img, conf=self.confidence_thresh, verbose=False) | |
cigarette_detected_in_current_frame = False | |
# κ²°κ³Όμμ 'cigarette' κ°μ²΄κ° νμ§λμλμ§ νμΈ | |
if results and len(results) > 0: | |
for box in results[0].boxes: | |
class_id = int(box.cls[0]) | |
confidence = float(box.conf[0]) | |
class_name = self.model.names[class_id] | |
if class_name == 'cigarette' and confidence >= self.confidence_thresh: | |
cigarette_detected_in_current_frame = True | |
break # νλλΌλ νμ§λλ©΄ λ μ΄μ νμΈν νμ μμ | |
# --- ν΄λΌμ΄μΈνΈ μ리 μλ¦Ό λ‘μ§ (λ°μ΄ν° μ±λ μ¬μ©) --- | |
# λ΄λ°°κ° νμ¬ νλ μμμ νμ§λμκ³ , λ°μ΄ν° μ±λμ΄ μ΄λ € μμΌλ©°, | |
# λ©μμ§ μ μ‘ κ°κ²©μ λλ¬νμ λ λ©μμ§ μ μ‘ | |
if cigarette_detected_in_current_frame and self._data_channel and self._data_channel.readyState == "open": | |
if not self.detected_in_prev_frame or self.frame_counter % self.send_interval == 0: | |
# print("Sending DETECT_CIGARETTE message to client...") # λλ²κ·Έ μΆλ ₯ | |
await self._data_channel.send("DETECT_CIGARETTE") # ν΄λΌμ΄μΈνΈλ‘ λ©μμ§ μ μ‘ | |
# λ€μ νλ μμ μν΄ νμ¬ νμ§ μν μ μ₯ | |
self.detected_in_prev_frame = cigarette_detected_in_current_frame | |
# νμ§λ κ²°κ³Όλ₯Ό μ΄λ―Έμ§μ νμ (λ°μ΄λ© λ°μ€, λΌλ²¨ λ±) | |
annotated_img = results[0].plot() | |
# μ²λ¦¬λ μ΄λ―Έμ§(numpy)λ₯Ό λ€μ AV νλ μμΌλ‘ λ³ννμ¬ λ°ν | |
return av.VideoFrame.from_ndarray(annotated_img, format="bgr24") | |
# --- ν΄λΌμ΄μΈνΈ μΈ‘ JavaScript μ½λ (λ°μ΄ν° μ±λ λ©μμ§ μμ λ° μ리 μ¬μ) --- | |
# webrtc_streamerμ on_data_channel μΈμμ μ λ¬λ JavaScript ν¨μ μ μ | |
# μ΄ ν¨μλ data channel κ°μ²΄λ₯Ό μΈμλ‘ λ°μ΅λλ€. | |
# λ©μμ§λ₯Ό λ°μΌλ©΄ μΉ μ€λμ€ APIλ‘ μ¬μΈνλ₯Ό μμ±νμ¬ μ¬μν©λλ€. | |
JS_CLIENT_SOUND_SCRIPT = """ | |
(channel) => { | |
// μ€λμ€ μ»¨ν μ€νΈ μμ± (ν΄λ¦ λ± μ¬μ©μ μνΈμμ© νμ μμ±ν΄μΌ ν μ μμ) | |
// webrtc_streamer μμ λ²νΌμ΄ μ΄λ―Έ μνΈμμ© μν μ ν©λλ€. | |
const audioContext = new (window.AudioContext || window.webkitAudioContext)(); | |
let lastPlayTime = 0; // λ§μ§λ§ μ리 μ¬μ μκ° (ms) | |
const playCooldown = 200; // μ리 μ¬μ μ΅μ κ°κ²© (ms) | |
// μ¬μΈν μ리λ₯Ό μ¬μνλ ν¨μ | |
const playSineWaveAlert = () => { | |
const now = audioContext.currentTime * 1000; // νμ¬ μκ°μ λ°λ¦¬μ΄λ‘ λ³ν | |
if (now - lastPlayTime < playCooldown) { | |
// console.log("Cooldown active. Skipping sound."); // λλ²κ·Έ μΆλ ₯ | |
return; // μΏ¨λ€μ΄ μ€μ΄λ©΄ μ¬μνμ§ μμ | |
} | |
lastPlayTime = now; // λ§μ§λ§ μ¬μ μκ° μ λ°μ΄νΈ | |
try { | |
const oscillator = audioContext.createOscillator(); | |
const gainNode = audioContext.createGain(); | |
oscillator.type = 'sine'; // μ¬μΈν | |
oscillator.frequency.setValueAtTime(600, audioContext.currentTime); // μ£Όνμ (μ: 600 Hz) | |
gainNode.gain.setValueAtTime(0.3, audioContext.currentTime); // λ³Όλ₯¨ (0.0 ~ 1.0) | |
oscillator.connect(gainNode); | |
gainNode.connect(audioContext.destination); | |
oscillator.start(); | |
oscillator.stop(audioContext.currentTime + 0.2); // 0.2μ΄ μ¬μ | |
// console.log("Playing sine wave sound."); // λλ²κ·Έ μΆλ ₯ | |
} catch (e) { | |
console.error("Error playing sine wave:", e); | |
} | |
}; | |
// λ°μ΄ν° μ±λλ‘λΆν° λ©μμ§λ₯Ό μμ νμ λ μ€νλ μ½λ°± ν¨μ | |
channel.onmessage = (event) => { | |
// console.log("Received message:", event.data); // μμ λ©μμ§ νμΈ | |
if (event.data === "DETECT_CIGARETTE") { | |
// μλ²μμ λ΄λ°° νμ§ λ©μμ§λ₯Ό λ°μΌλ©΄ μ리 μ¬μ | |
playSineWaveAlert(); | |
} | |
}; | |
// λ°μ΄ν° μ±λμ΄ μ΄λ Έμ λ | |
channel.onopen = () => { | |
console.log("Data channel opened!"); | |
}; | |
// λ°μ΄ν° μ±λμ΄ λ«νμ λ | |
channel.onclose = () => { | |
console.log("Data channel closed."); | |
}; | |
// λ°μ΄ν° μ±λ μλ¬ λ°μ μ | |
channel.onerror = (error) => { | |
console.error("Data channel error:", error); | |
}; | |
} | |
""" | |
# --- Streamlit μ± λ μ΄μμ κ΅¬μ± --- | |
st.title("π¬ μ€μκ° λ΄λ°° νμ§ μΉ μ ν리μΌμ΄μ (ν΄λΌμ΄μΈνΈ μ리)") | |
st.write(""" | |
μΉμΊ νΌλλ₯Ό ν΅ν΄ λ΄λ°° κ°μ²΄λ₯Ό μ€μκ°μΌλ‘ νμ§νκ³ μμμ νμν©λλ€. | |
λ΄λ°°κ° νμ§λλ©΄ **μ¬μ©μμ λΈλΌμ°μ **μμ μλ¦Ό μ리(μ¬μΈν)κ° μ¬μλ©λλ€. | |
**μ£Όμ:** | |
* μ΄ μ±μ μ¬μ©μμ λΈλΌμ°μ μΉμΊ λ° μ€λμ€ μ¬μ κΆνμ΄ νμν©λλ€. λΈλΌμ°μ μμ² μ νμ©ν΄μ£ΌμΈμ. | |
* λ€νΈμν¬ μν λ° μ»΄ν¨ν° μ±λ₯μ λ°λΌ μμ μ²λ¦¬μ μ§μ°μ΄ λ°μν μ μμ΅λλ€. | |
* `trained_model.pt` νμΌμ΄ μ€ν¬λ¦½νΈ νμΌκ³Ό κ°μ λλ ν 리μ μλμ§ νμΈνμΈμ. | |
""") | |
st.write("---") | |
st.subheader("μΉμΊ μ€νΈλ¦Ό λ° λ΄λ°° νμ§ κ²°κ³Ό") | |
# RTC μ€μ (NAT ν΅κ³Όλ₯Ό μν΄ νμ, Google STUN μλ² μ¬μ©) | |
# λλΆλΆμ κ²½μ° κΈ°λ³Έ μ€μ μΌλ‘ μΆ©λΆνλ, λͺ μμ μΌλ‘ μ€μ ν μ μμ΅λλ€. | |
rtc_configuration = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}) | |
# Streamlit-WebRTC μ»΄ν¬λνΈ μΆκ° | |
webrtc_ctx = webrtc_streamer( | |
key="yolo-detection-client-sound", # κ³ μ ν€ | |
mode=WebRtcMode.SENDRECV, # λΉλμ€λ₯Ό 보λ΄κ³ (SEND) μλ²μμ μ²λ¦¬λ λΉλμ€λ₯Ό λ€μ λ°μ (RECV) | |
video_processor_factory=lambda: YOLOVideoTransformer( # λΉλμ€ λ³ν ν΄λμ€ ν©ν 리 | |
model=model, | |
confidence_thresh=CONFIDENCE_THRESHOLD, | |
send_interval=SEND_ALERT_INTERVAL, | |
# NOTE: video_processor_factoryκ° lambda ν¨μλ‘ μ¬μ©λ λ, | |
# webrtc_streamer λ΄λΆμ μΌλ‘ μμ±λ data_channel κ°μ²΄κ° VideoTransformer μΈμ€ν΄μ€μ μ λ¬λ©λλ€. | |
# λͺ μμ μΌλ‘ lambda μΈμλ‘ channelμ λ°μ§ μμλ λ©λλ€. | |
# (webrtc_streamerμ ꡬν λ°©μμ λ°λΌ λ€λ₯Ό μ μμΌλ―λ‘ λ¬Έμ νμΈ νμ) | |
# μ΅μ λ²μ μμλ __init__μ data_channel=data_channel ννλ‘ μ λ¬λ¨ | |
data_channel=None # μ΄κΈ°κ°μ None, webrtc_streamerκ° μΈμ€ν΄μ€ μμ± μ μ€μ κ°μ²΄ μ£Όμ | |
# -> μλ, lambda ν©ν λ¦¬κ° μΈμλ₯Ό λ°λλ‘ λ³κ²½ν΄μΌ ν¨ | |
# lambda channel: YOLOVideoTransformer(..., data_channel=channel) μ΄ λ μ νν¨ | |
), | |
rtc_configuration=rtc_configuration, | |
media_stream_constraints={"video": True, "audio": False}, # μΉμΊ λΉλμ€λ§ μ¬μ© | |
async_processing=True, # λΉλμ€ μ²λ¦¬λ₯Ό λΉλκΈ°λ‘ μ€ν | |
on_data_channel=JS_CLIENT_SOUND_SCRIPT # λ°μ΄ν° μ±λ κ΄λ ¨ ν΄λΌμ΄μΈνΈ JS μ½λ | |
) | |
# μ video_processor_factory lambda λΆλΆμ λ€μκ³Ό κ°μ΄ λͺ μμ μΌλ‘ data_channelμ λ°λλ‘ μμ ν©λλ€. | |
# lambda channel: YOLOVideoTransformer( | |
# model=model, | |
# confidence_thresh=CONFIDENCE_THRESHOLD, | |
# send_interval=SEND_ALERT_INTERVAL, | |
# data_channel=channel # λ°μ΄ν° μ±λ κ°μ²΄ μ λ¬ | |
# ), | |
st.write("---") | |
st.info("μΉμΊ μ€νΈλ¦Όμ μμνλ©΄ λΈλΌμ°μ μμ λ΄λ°° νμ§ μ μλ¦Ό μλ¦¬κ° μ¬μλ©λλ€.") |