ciga / app.py
kimhyunwoo's picture
Create app.py
7fe5267 verified
import streamlit as st
import av # streamlit-webrtcκ°€ λΉ„λ””μ˜€ ν”„λ ˆμž„μ„ 닀루기 μœ„ν•΄ μ‚¬μš©
import cv2
import numpy as np
from ultralytics import YOLO
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode
# --- μ„€μ • ---
MODEL_PATH = 'trained_model.pt' # ν•™μŠ΅ν•œ YOLO λͺ¨λΈ 파일 경둜
CONFIDENCE_THRESHOLD = 0.4 # 객체 탐지 μ΅œμ†Œ 신뒰도 (λͺ¨λΈμ— 따라 μ‘°μ •)
SEND_ALERT_INTERVAL = 30 # λ‹΄λ°° 탐지 μ‹œ 데이터 채널 λ©”μ‹œμ§€λ₯Ό λͺ‡ ν”„λ ˆμž„λ§ˆλ‹€ 보낼지 (λ„ˆλ¬΄ 짧으면 ν΄λΌμ΄μ–ΈνŠΈ λΆ€λ‹΄)
# --- YOLO λͺ¨λΈ λ‘œλ“œ ---
@st.cache_resource # Streamlit의 캐싱 κΈ°λŠ₯ μ‚¬μš©: μ•± μ‹€ν–‰ 쀑 λͺ¨λΈμ„ ν•œ 번만 λ‘œλ“œ
def load_yolo_model(model_path):
try:
model = YOLO(model_path)
if hasattr(model, 'model') and model.model is not None:
st.success(f"YOLO λͺ¨λΈ λ‘œλ“œ 성곡: {model_path}")
return model
else:
st.error(f"YOLO λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨ λ˜λŠ” 객체 μ΄ˆκΈ°ν™” 문제: {model_path}")
st.stop()
except FileNotFoundError:
st.error(f"였λ₯˜: λͺ¨λΈ 파일이 μ—†μŠ΅λ‹ˆλ‹€. '{model_path}' 경둜λ₯Ό ν™•μΈν•˜μ„Έμš”.")
st.stop() # λͺ¨λΈ 파일 μ—†μœΌλ©΄ μ•± 쀑지
except Exception as e:
st.error(f"YOLO λͺ¨λΈ λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
st.stop() # λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨ μ‹œ μ•± 쀑지
model = load_yolo_model(MODEL_PATH)
# --- Streamlit-WebRTCλ₯Ό μœ„ν•œ λΉ„λ””μ˜€ λ³€ν™˜ 클래슀 ---
# 이 클래슀의 μΈμŠ€ν„΄μŠ€κ°€ λΉ„λ””μ˜€ ν”„λ ˆμž„λ§ˆλ‹€ ν˜ΈμΆœλ©λ‹ˆλ‹€.
class YOLOVideoTransformer(VideoTransformerBase):
# __init__에 data_channel 인자λ₯Ό μΆ”κ°€ν•˜μ—¬ ν΄λΌμ΄μ–ΈνŠΈμ™€ 톡신
def __init__(self, model, confidence_thresh, send_interval, data_channel):
self.model = model
self.confidence_thresh = confidence_thresh
self.send_interval = send_interval
self._data_channel = data_channel # ν΄λΌμ΄μ–ΈνŠΈμ™€ 톡신할 데이터 채널 객체
self.detected_in_prev_frame = False # 이전 ν”„λ ˆμž„μ—μ„œ νƒμ§€λ˜μ—ˆλŠ”μ§€ μ—¬λΆ€
self.frame_counter = 0 # ν”„λ ˆμž„ μΉ΄μš΄ν„°
# 각 λΉ„λ””μ˜€ ν”„λ ˆμž„μ„ μ²˜λ¦¬ν•˜λŠ” λ©”μ„œλ“œ (비동기 ν•¨μˆ˜λ‘œ μ •μ˜)
async def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
self.frame_counter += 1
# AV ν”„λ ˆμž„μ„ OpenCV(numpy) μ΄λ―Έμ§€λ‘œ λ³€ν™˜
img = frame.to_ndarray(format="bgr24")
# YOLOv8 λͺ¨λΈλ‘œ 객체 탐지
# verbose=False: μ½˜μ†”μ— 탐지 κ²°κ³Ό 좜λ ₯ μ•ˆ 함
results = self.model(img, conf=self.confidence_thresh, verbose=False)
cigarette_detected_in_current_frame = False
# κ²°κ³Όμ—μ„œ 'cigarette' 객체가 νƒμ§€λ˜μ—ˆλŠ”μ§€ 확인
if results and len(results) > 0:
for box in results[0].boxes:
class_id = int(box.cls[0])
confidence = float(box.conf[0])
class_name = self.model.names[class_id]
if class_name == 'cigarette' and confidence >= self.confidence_thresh:
cigarette_detected_in_current_frame = True
break # ν•˜λ‚˜λΌλ„ νƒμ§€λ˜λ©΄ 더 이상 확인할 ν•„μš” μ—†μŒ
# --- ν΄λΌμ΄μ–ΈνŠΈ μ†Œλ¦¬ μ•Œλ¦Ό 둜직 (데이터 채널 μ‚¬μš©) ---
# λ‹΄λ°°κ°€ ν˜„μž¬ ν”„λ ˆμž„μ—μ„œ νƒμ§€λ˜μ—ˆκ³ , 데이터 채널이 μ—΄λ € 있으며,
# λ©”μ‹œμ§€ 전솑 간격에 λ„λ‹¬ν–ˆμ„ λ•Œ λ©”μ‹œμ§€ 전솑
if cigarette_detected_in_current_frame and self._data_channel and self._data_channel.readyState == "open":
if not self.detected_in_prev_frame or self.frame_counter % self.send_interval == 0:
# print("Sending DETECT_CIGARETTE message to client...") # 디버그 좜λ ₯
await self._data_channel.send("DETECT_CIGARETTE") # ν΄λΌμ΄μ–ΈνŠΈλ‘œ λ©”μ‹œμ§€ 전솑
# λ‹€μŒ ν”„λ ˆμž„μ„ μœ„ν•΄ ν˜„μž¬ 탐지 μƒνƒœ μ €μž₯
self.detected_in_prev_frame = cigarette_detected_in_current_frame
# νƒμ§€λœ κ²°κ³Όλ₯Ό 이미지에 ν‘œμ‹œ (λ°”μš΄λ”© λ°•μŠ€, 라벨 λ“±)
annotated_img = results[0].plot()
# 처리된 이미지(numpy)λ₯Ό λ‹€μ‹œ AV ν”„λ ˆμž„μœΌλ‘œ λ³€ν™˜ν•˜μ—¬ λ°˜ν™˜
return av.VideoFrame.from_ndarray(annotated_img, format="bgr24")
# --- ν΄λΌμ΄μ–ΈνŠΈ μΈ‘ JavaScript μ½”λ“œ (데이터 채널 λ©”μ‹œμ§€ μˆ˜μ‹  및 μ†Œλ¦¬ μž¬μƒ) ---
# webrtc_streamer의 on_data_channel μΈμžμ— 전달될 JavaScript ν•¨μˆ˜ μ •μ˜
# 이 ν•¨μˆ˜λŠ” data channel 객체λ₯Ό 인자둜 λ°›μŠ΅λ‹ˆλ‹€.
# λ©”μ‹œμ§€λ₯Ό λ°›μœΌλ©΄ μ›Ή μ˜€λ””μ˜€ API둜 μ‚¬μΈνŒŒλ₯Ό μƒμ„±ν•˜μ—¬ μž¬μƒν•©λ‹ˆλ‹€.
JS_CLIENT_SOUND_SCRIPT = """
(channel) => {
// μ˜€λ””μ˜€ μ»¨ν…μŠ€νŠΈ 생성 (클릭 λ“± μ‚¬μš©μž μƒν˜Έμž‘μš© 후에 생성해야 ν•  수 있음)
// webrtc_streamer μ‹œμž‘ λ²„νŠΌμ΄ 이미 μƒν˜Έμž‘μš© 역할을 ν•©λ‹ˆλ‹€.
const audioContext = new (window.AudioContext || window.webkitAudioContext)();
let lastPlayTime = 0; // λ§ˆμ§€λ§‰ μ†Œλ¦¬ μž¬μƒ μ‹œκ°„ (ms)
const playCooldown = 200; // μ†Œλ¦¬ μž¬μƒ μ΅œμ†Œ 간격 (ms)
// μ‚¬μΈνŒŒ μ†Œλ¦¬λ₯Ό μž¬μƒν•˜λŠ” ν•¨μˆ˜
const playSineWaveAlert = () => {
const now = audioContext.currentTime * 1000; // ν˜„μž¬ μ‹œκ°„μ„ λ°€λ¦¬μ΄ˆλ‘œ λ³€ν™˜
if (now - lastPlayTime < playCooldown) {
// console.log("Cooldown active. Skipping sound."); // 디버그 좜λ ₯
return; // μΏ¨λ‹€μš΄ 쀑이면 μž¬μƒν•˜μ§€ μ•ŠμŒ
}
lastPlayTime = now; // λ§ˆμ§€λ§‰ μž¬μƒ μ‹œκ°„ μ—…λ°μ΄νŠΈ
try {
const oscillator = audioContext.createOscillator();
const gainNode = audioContext.createGain();
oscillator.type = 'sine'; // μ‚¬μΈνŒŒ
oscillator.frequency.setValueAtTime(600, audioContext.currentTime); // 주파수 (예: 600 Hz)
gainNode.gain.setValueAtTime(0.3, audioContext.currentTime); // λ³Όλ₯¨ (0.0 ~ 1.0)
oscillator.connect(gainNode);
gainNode.connect(audioContext.destination);
oscillator.start();
oscillator.stop(audioContext.currentTime + 0.2); // 0.2초 μž¬μƒ
// console.log("Playing sine wave sound."); // 디버그 좜λ ₯
} catch (e) {
console.error("Error playing sine wave:", e);
}
};
// 데이터 μ±„λ„λ‘œλΆ€ν„° λ©”μ‹œμ§€λ₯Ό μˆ˜μ‹ ν–ˆμ„ λ•Œ 싀행될 콜백 ν•¨μˆ˜
channel.onmessage = (event) => {
// console.log("Received message:", event.data); // μˆ˜μ‹  λ©”μ‹œμ§€ 확인
if (event.data === "DETECT_CIGARETTE") {
// μ„œλ²„μ—μ„œ λ‹΄λ°° 탐지 λ©”μ‹œμ§€λ₯Ό λ°›μœΌλ©΄ μ†Œλ¦¬ μž¬μƒ
playSineWaveAlert();
}
};
// 데이터 채널이 열렸을 λ•Œ
channel.onopen = () => {
console.log("Data channel opened!");
};
// 데이터 채널이 λ‹«ν˜”μ„ λ•Œ
channel.onclose = () => {
console.log("Data channel closed.");
};
// 데이터 채널 μ—λŸ¬ λ°œμƒ μ‹œ
channel.onerror = (error) => {
console.error("Data channel error:", error);
};
}
"""
# --- Streamlit μ•± λ ˆμ΄μ•„μ›ƒ ꡬ성 ---
st.title("🚬 μ‹€μ‹œκ°„ λ‹΄λ°° 탐지 μ›Ή μ• ν”Œλ¦¬μΌ€μ΄μ…˜ (ν΄λΌμ΄μ–ΈνŠΈ μ†Œλ¦¬)")
st.write("""
μ›ΉμΊ  ν”Όλ“œλ₯Ό 톡해 λ‹΄λ°° 객체λ₯Ό μ‹€μ‹œκ°„μœΌλ‘œ νƒμ§€ν•˜κ³  μ˜μƒμ— ν‘œμ‹œν•©λ‹ˆλ‹€.
λ‹΄λ°°κ°€ νƒμ§€λ˜λ©΄ **μ‚¬μš©μžμ˜ λΈŒλΌμš°μ €**μ—μ„œ μ•Œλ¦Ό μ†Œλ¦¬(μ‚¬μΈνŒŒ)κ°€ μž¬μƒλ©λ‹ˆλ‹€.
**주의:**
* 이 앱은 μ‚¬μš©μžμ˜ λΈŒλΌμš°μ € μ›ΉμΊ  및 μ˜€λ””μ˜€ μž¬μƒ κΆŒν•œμ΄ ν•„μš”ν•©λ‹ˆλ‹€. λΈŒλΌμš°μ € μš”μ²­ μ‹œ ν—ˆμš©ν•΄μ£Όμ„Έμš”.
* λ„€νŠΈμ›Œν¬ μƒνƒœ 및 컴퓨터 μ„±λŠ₯에 따라 μ˜μƒ μ²˜λ¦¬μ— 지연이 λ°œμƒν•  수 μžˆμŠ΅λ‹ˆλ‹€.
* `trained_model.pt` 파일이 슀크립트 파일과 같은 디렉토리에 μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.
""")
st.write("---")
st.subheader("μ›ΉμΊ  슀트림 및 λ‹΄λ°° 탐지 κ²°κ³Ό")
# RTC μ„€μ • (NAT 톡과λ₯Ό μœ„ν•΄ ν•„μš”, Google STUN μ„œλ²„ μ‚¬μš©)
# λŒ€λΆ€λΆ„μ˜ 경우 κΈ°λ³Έ μ„€μ •μœΌλ‘œ μΆ©λΆ„ν•˜λ‚˜, λͺ…μ‹œμ μœΌλ‘œ μ„€μ •ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
rtc_configuration = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
# Streamlit-WebRTC μ»΄ν¬λ„ŒνŠΈ μΆ”κ°€
webrtc_ctx = webrtc_streamer(
key="yolo-detection-client-sound", # 고유 ν‚€
mode=WebRtcMode.SENDRECV, # λΉ„λ””μ˜€λ₯Ό 보내고 (SEND) μ„œλ²„μ—μ„œ 처리된 λΉ„λ””μ˜€λ₯Ό λ‹€μ‹œ λ°›μŒ (RECV)
video_processor_factory=lambda: YOLOVideoTransformer( # λΉ„λ””μ˜€ λ³€ν™˜ 클래슀 νŒ©ν† λ¦¬
model=model,
confidence_thresh=CONFIDENCE_THRESHOLD,
send_interval=SEND_ALERT_INTERVAL,
# NOTE: video_processor_factoryκ°€ lambda ν•¨μˆ˜λ‘œ μ‚¬μš©λ  λ•Œ,
# webrtc_streamer λ‚΄λΆ€μ μœΌλ‘œ μƒμ„±λœ data_channel 객체가 VideoTransformer μΈμŠ€ν„΄μŠ€μ— μ „λ‹¬λ©λ‹ˆλ‹€.
# λͺ…μ‹œμ μœΌλ‘œ lambda 인자둜 channel을 λ°›μ§€ μ•Šμ•„λ„ λ©λ‹ˆλ‹€.
# (webrtc_streamer의 κ΅¬ν˜„ 방식에 따라 λ‹€λ₯Ό 수 μžˆμœΌλ―€λ‘œ λ¬Έμ„œ 확인 ν•„μš”)
# μ΅œμ‹  λ²„μ „μ—μ„œλŠ” __init__에 data_channel=data_channel ν˜•νƒœλ‘œ 전달됨
data_channel=None # μ΄ˆκΈ°κ°’μ€ None, webrtc_streamerκ°€ μΈμŠ€ν„΄μŠ€ 생성 μ‹œ μ‹€μ œ 객체 μ£Όμž…
# -> μ•„λ‹˜, lambda νŒ©ν† λ¦¬κ°€ 인자λ₯Ό 받도둝 λ³€κ²½ν•΄μ•Ό 함
# lambda channel: YOLOVideoTransformer(..., data_channel=channel) 이 더 정확함
),
rtc_configuration=rtc_configuration,
media_stream_constraints={"video": True, "audio": False}, # μ›ΉμΊ  λΉ„λ””μ˜€λ§Œ μ‚¬μš©
async_processing=True, # λΉ„λ””μ˜€ 처리λ₯Ό λΉ„λ™κΈ°λ‘œ μ‹€ν–‰
on_data_channel=JS_CLIENT_SOUND_SCRIPT # 데이터 채널 κ΄€λ ¨ ν΄λΌμ΄μ–ΈνŠΈ JS μ½”λ“œ
)
# μœ„ video_processor_factory lambda 뢀뢄을 λ‹€μŒκ³Ό 같이 λͺ…μ‹œμ μœΌλ‘œ data_channel을 받도둝 μˆ˜μ •ν•©λ‹ˆλ‹€.
# lambda channel: YOLOVideoTransformer(
# model=model,
# confidence_thresh=CONFIDENCE_THRESHOLD,
# send_interval=SEND_ALERT_INTERVAL,
# data_channel=channel # 데이터 채널 객체 전달
# ),
st.write("---")
st.info("μ›ΉμΊ  μŠ€νŠΈλ¦Όμ„ μ‹œμž‘ν•˜λ©΄ λΈŒλΌμš°μ €μ—μ„œ λ‹΄λ°° 탐지 μ‹œ μ•Œλ¦Ό μ†Œλ¦¬κ°€ μž¬μƒλ©λ‹ˆλ‹€.")