"""Streamlit demo that classifies cricket shots in bundled sample videos.

A sample video is copied into a private working directory, converted to .mp4,
split into frames, and each frame is embedded with a CLIP or SigLIP image
encoder. The stacked frame embeddings are passed to an ``LSTMNetwork``
classifier as a single batch item, and the class probabilities are shown.
"""
import os
import re
import shutil
import tempfile

from PIL import Image
import ffmpeg  # noqa: F401  (not referenced directly in this module)
import streamlit as st
import torch
from transformers import AutoProcessor, AutoModel

from src.lstm_model import LSTMNetwork
from src.frames import extract_frames, convert_to_mp4

# Index -> label mapping for the classifier's 10 outputs.
idx_to_class = {0: 'cover', 1: 'defense', 2: 'flick', 3: 'hook', 4: 'late_cut',
                5: 'lofted', 6: 'pull', 7: 'square_cut', 8: 'straight', 9: 'sweep'}
# Inverse of idx_to_class (label -> index).
class_label_mapping = {'cover': 0, 'defense': 1, 'flick': 2, 'hook': 3, 'late_cut': 4,
                       'lofted': 5, 'pull': 6, 'square_cut': 7, 'straight': 8, 'sweep': 9}

# Defining the paths
CLIP_MODEL_PATH = "clip-cricket-classifier.pt"
SIGLIP_MODEL_PATH = "siglip-cricket-classifier.pt"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"

# device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# LSTM input width used for each trained checkpoint; it must match the size of
# the image features produced by the corresponding backbone.
_INPUT_SIZES = {CLIP_MODEL_PATH: 512, SIGLIP_MODEL_PATH: 768}

# Radio-button label -> (Hugging Face model id, LSTM checkpoint path).
_MODEL_CHOICES = {
    "CLIP": (CLIP_MODEL_ID, CLIP_MODEL_PATH),
    "SIGLIP": (SIGLIP_MODEL_ID, SIGLIP_MODEL_PATH),
}

_FRAME_EXTENSIONS = ('.jpg', '.jpeg', '.png')


@st.cache_resource
def embeddings_creators(MODEL_ID):
    """Load the Hugging Face processor and model for ``MODEL_ID`` onto ``device``.

    Cached with ``st.cache_resource`` so the backbone is loaded once per
    process instead of on every Streamlit rerun.

    Returns:
        ``(embedding_processor, embedding_model)``.
    """
    embedding_processor = AutoProcessor.from_pretrained(MODEL_ID)
    embedding_model = AutoModel.from_pretrained(MODEL_ID)
    embedding_model.to(device)
    embedding_model.eval()
    return embedding_processor, embedding_model


@st.cache_resource
def load_model(MODEL_PATH):
    """Build the LSTM classifier for a known checkpoint and load its weights.

    Raises:
        ValueError: if ``MODEL_PATH`` is not one of the known checkpoints.
    """
    try:
        input_size = _INPUT_SIZES[MODEL_PATH]
    except KeyError:
        raise ValueError(f"Invalid model path: {MODEL_PATH}") from None
    model = LSTMNetwork(input_size=input_size, hidden_size=256, num_classes=10).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    # Inference only: disable any training-time behaviour such as dropout.
    model.eval()
    return model


def _natural_key(name):
    """Sort key that orders 'frame_2' before 'frame_10' (digit runs compared as ints)."""
    return [int(tok) if tok.isdecimal() else tok.lower() for tok in re.split(r"(\d+)", name)]


def _prepare_video(selected_video, work_dir):
    """Copy a sample video from ``assets`` into ``work_dir`` and return an .mp4 path."""
    new_video_path = os.path.join(work_dir, selected_video)
    shutil.copy2(os.path.join("assets", selected_video), new_video_path)
    if new_video_path.lower().endswith('.mp4'):
        return new_video_path
    stem = os.path.splitext(os.path.basename(new_video_path))[0]
    final_video_path = os.path.join(work_dir, f"{stem}.mp4")
    convert_to_mp4(new_video_path, final_video_path)
    return final_video_path


def _load_frames(frames_dir):
    """Return the frame images in ``frames_dir`` as RGB, in natural filename order.

    The classifier consumes frames as a sequence, so the order matters;
    ``os.listdir`` alone returns an arbitrary order.
    """
    names = sorted(
        (f for f in os.listdir(frames_dir) if f.lower().endswith(_FRAME_EXTENSIONS)),
        key=_natural_key,
    )
    images = []
    for name in names:
        # Context manager closes the file handle; convert() yields a loaded copy.
        with Image.open(os.path.join(frames_dir, name)) as img:
            images.append(img.convert("RGB"))
    return images


def _predict(embedding_processor, embedding_model, model, images):
    """Return the 1-D tensor of class probabilities for a sequence of frames."""
    with torch.no_grad():
        tokens = embedding_processor(
            text=None, images=images, return_tensors="pt"
        ).to(device)
        inference_embeddings = embedding_model.get_image_features(**tokens)
        # Add a leading batch dimension: all frames form one input sequence.
        output = model(inference_embeddings.unsqueeze(0))
    return output.softmax(dim=1)[0]


def app():
    """Render the page: choose a model and a sample video, then show predictions."""
    st.image("assets/banner.png")
    st.title("Cricket Shot Classifier", anchor=False)

    model_choice = st.radio("Select a model", ["None", "CLIP", "SIGLIP"])
    if model_choice == "None":
        # The hint must be written before st.stop(), which ends this script run.
        st.write("Please select a model")
        st.stop()

    model_id, model_path = _MODEL_CHOICES[model_choice]
    embedding_processor, embedding_model = embeddings_creators(model_id)
    model = load_model(model_path)

    # List sample videos from assets folder
    sample_videos = sorted(f for f in os.listdir("assets") if f.endswith('.avi'))
    if not sample_videos:
        st.error("No sample videos found in assets folder")
        st.stop()

    selected_video = st.selectbox("Select a sample video", sample_videos)

    save_directory = './demo'
    os.makedirs(save_directory, exist_ok=True)
    # A private per-run directory keeps concurrent sessions from clobbering each
    # other, and stops stale frames from an interrupted run leaking into this one.
    work_dir = tempfile.mkdtemp(prefix="run_", dir=save_directory)
    try:
        final_video_path = _prepare_video(selected_video, work_dir)
        st.video(final_video_path)

        frames_dir = os.path.join(work_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)
        extract_frames(final_video_path, frames_dir)
        st.write("Frames extracted from the video.")

        inference_images = _load_frames(frames_dir)
        if not inference_images:
            st.error("No frames could be extracted from the video.")
            st.stop()

        prob = _predict(embedding_processor, embedding_model, model, inference_images)
        _, indices = torch.sort(prob, descending=True)
        for i in indices.tolist():
            st.write(f"Prediction: {idx_to_class[i]}")
            st.progress(int(prob[i].item() * 100))
    finally:
        # Runs on success, on errors, and on st.stop() alike.
        try:
            shutil.rmtree(work_dir)
            print(f"Folder '{work_dir}' and its contents have been deleted.")
        except OSError as e:
            print(f"Error while deleting folder '{work_dir}': {e}")


if __name__ == "__main__":
    app()