# NOTE: "Spaces: Running / Running" was Hugging Face Spaces page chrome captured
# by the scrape, not application source; preserved here only as a comment.
import streamlit as st | |
import time | |
import requests | |
from streamlit.components.v1 import html | |
import os | |
from dotenv import load_dotenv | |
# Voice input dependencies | |
import torchaudio | |
import numpy as np | |
import torch | |
from io import BytesIO | |
import hashlib | |
from audio_recorder_streamlit import audio_recorder | |
from transformers import pipeline | |
######################################
# Voice Input Helper Functions
######################################
@st.cache_resource(show_spinner=False)
def load_voice_model():
    """Load the Whisper ASR pipeline, cached for the app's lifetime.

    Fix: without ``st.cache_resource`` this re-downloads/re-instantiates the
    whisper-base model on every Streamlit rerun (i.e. every widget
    interaction), which is prohibitively slow. Caching preserves the callable
    interface for existing callers (``get_voice_transcription``).

    Returns:
        A transformers ``automatic-speech-recognition`` pipeline.
    """
    return pipeline("automatic-speech-recognition", model="openai/whisper-base")
def process_audio(audio_bytes):
    """Decode recorded audio bytes into Whisper's expected input format.

    Args:
        audio_bytes: raw encoded audio (as produced by ``audio_recorder``).

    Returns:
        A dict ``{"raw": 1-D float numpy array, "sampling_rate": 16000}``
        suitable for a transformers ASR pipeline.
    """
    waveform, sample_rate = torchaudio.load(BytesIO(audio_bytes))
    # Downmix multi-channel recordings to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Whisper models are trained on 16 kHz audio; resample anything else.
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform)
    return {"raw": waveform.numpy().squeeze(), "sampling_rate": 16000}
def get_voice_transcription(state_key):
    """Render an audio-recorder widget and accumulate transcribed speech.

    Args:
        state_key: session-state key under which the running transcription
            text is stored (also used to derive the widget and hash keys).

    Returns:
        The accumulated transcription string held in session state.
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = ""
    audio_bytes = audio_recorder(
        key=state_key + "_audio",
        pause_threshold=0.8,
        # NOTE(review): the label below looks mojibake'd (likely a mic emoji
        # originally) — confirm against the deployed app before changing it.
        text="๐๏ธ Speak your message",
        recording_color="#e8b62c",
        neutral_color="#6aa36f"
    )
    if audio_bytes:
        # Hash the clip so a Streamlit rerun doesn't re-transcribe the
        # same recording a second time.
        current_hash = hashlib.md5(audio_bytes).hexdigest()
        last_hash_key = state_key + "_last_hash"
        if st.session_state.get(last_hash_key, "") != current_hash:
            st.session_state[last_hash_key] = current_hash
            try:
                audio_input = process_audio(audio_bytes)
                whisper = load_voice_model()
                transcribed_text = whisper(audio_input)["text"]
                st.info(f"๐ Transcribed: {transcribed_text}")
                st.session_state[state_key] += (" " + transcribed_text).strip()
                # Fix: st.experimental_rerun() was deprecated and then removed
                # from Streamlit; the rest of this file already uses st.rerun()
                # (see the restart/submit handlers), so use it here too.
                st.rerun()
            except Exception as e:
                st.error(f"Voice input error: {str(e)}")
    return st.session_state[state_key]
######################################
# Game Functions & Styling
######################################
def get_help_agent():
    """Build a BlenderBot-400M conversational pipeline.

    NOTE(review): not referenced anywhere in the visible code, so this may be
    dead code or used by a part of the file not shown here. Also, the
    "conversational" pipeline task was removed in recent transformers
    releases — confirm the pinned transformers version still supports it.
    """
    return pipeline("conversational", model="facebook/blenderbot-400M-distill")
def inject_custom_css():
    """Inject the app's global CSS theme (fonts, gradients, card styling).

    Defines the classes used by the markdown blocks in main():
    .title, .subtitle, .question-box, .input-box, .final-reveal, plus
    overrides for Streamlit's text inputs and buttons.
    """
    # unsafe_allow_html is required for raw <style> injection via st.markdown.
    st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
    * { font-family: 'Inter', sans-serif; }
    .title { font-size: 2.8rem !important; font-weight: 800 !important;
    background: linear-gradient(45deg, #6C63FF, #3B82F6);
    -webkit-background-clip: text; -webkit-text-fill-color: transparent;
    text-align: center; margin: 1rem 0; }
    .subtitle { font-size: 1.1rem; text-align: center; color: #64748B; margin-bottom: 2.5rem; }
    .question-box { background: white; border-radius: 20px; padding: 2rem; margin: 1.5rem 0;
    box-shadow: 0 10px 25px rgba(0,0,0,0.08); border: 1px solid #e2e8f0; color: black; }
    .input-box { background: white; border-radius: 12px; padding: 1.5rem; margin: 1rem 0;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05); }
    .stTextInput input { border: 2px solid #e2e8f0 !important; border-radius: 10px !important;
    padding: 12px 16px !important; }
    button { background: linear-gradient(45deg, #6C63FF, #3B82F6) !important;
    color: white !important; border-radius: 10px !important;
    padding: 12px 24px !important; font-weight: 600; }
    .final-reveal { font-size: 2.8rem;
    background: linear-gradient(45deg, #6C63FF, #3B82F6);
    -webkit-background-clip: text; -webkit-text-fill-color: transparent;
    text-align: center; margin: 2rem 0; font-weight: 800; }
    </style>
    """, unsafe_allow_html=True)
def show_confetti():
    """Fire a celebratory confetti burst in the browser.

    Embeds the canvas-confetti CDN script via a Streamlit HTML component and
    launches several overlapping bursts with varying spread/velocity for a
    fuller effect. Purely visual — no return value or state change.
    """
    html("""
    <canvas id="confetti-canvas" class="confetti"></canvas>
    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/confetti.browser.min.js"></script>
    <script>
    const count = 200;
    const defaults = { origin: { y: 0.7 }, zIndex: 1050 };
    function fire(particleRatio, opts) {
        confetti(Object.assign({}, defaults, opts, {
            particleCount: Math.floor(count * particleRatio)
        }));
    }
    fire(0.25, { spread: 26, startVelocity: 55 });
    fire(0.2, { spread: 60 });
    fire(0.35, { spread: 100, decay: 0.91, scalar: 0.8 });
    fire(0.1, { spread: 120, startVelocity: 25, decay: 0.92, scalar: 1.2 });
    fire(0.1, { spread: 120, startVelocity: 45 });
    </script>
    """)
def ask_llama(conversation_history, category, is_final_guess=False):
    """Query Groq's OpenAI-compatible chat API for the next move in the game.

    Args:
        conversation_history: list of ``{"role", "content"}`` chat messages
            exchanged so far (may be empty for the opening question).
        category: the secret's category ("Person", "Place", or "Object").
        is_final_guess: when True, request the model's single best guess
            instead of another yes/no question.

    Returns:
        The assistant's reply text, or ``"..."`` if the request failed
        (the error is shown to the user via ``st.error``).
    """
    api_url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}",
        "Content-Type": "application/json"
    }
    system_prompt = f"""You're playing 20 questions to guess a {category}. Rules:
1. Ask strategic, non-repeating yes/no questions to narrow down.
2. Use all previous answers smartly.
3. If you're 80%+ sure, say: Final Guess: [your guess]
4. For places: ask about continent, country, landmarks, etc.
5. For people: ask if real, profession, gender, etc.
6. For objects: ask about use, size, material, etc."""
    prompt = f"""Based on these answers about a {category}, provide ONLY your final guess with no extra text:
{conversation_history}""" if is_final_guess else "Ask your next smart yes/no question."
    messages = [{"role": "system", "content": system_prompt}]
    messages += conversation_history
    messages.append({"role": "user", "content": prompt})
    data = {
        # Fix: Groq's Llama 3 70B model ID is "llama3-70b-8192";
        # "llama-3-70b-8192" is rejected by the API.
        "model": "llama3-70b-8192",
        "messages": messages,
        "temperature": 0.8,
        "max_tokens": 100
    }
    try:
        # Fix: a timeout prevents the Streamlit script from hanging forever
        # if the API stalls.
        res = requests.post(api_url, headers=headers, json=data, timeout=30)
        res.raise_for_status()
        return res.json()["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"โ LLaMA API error: {e}")
        return "..."
######################################
# Main App Logic (UI, Game Loop)
######################################
def main():
    """Entry point: render the 20-Questions UI and drive the guessing loop."""
    load_dotenv()  # loads GROQ_API_KEY (read inside ask_llama) from a .env file
    inject_custom_css()
    # NOTE(review): several emoji below look mojibake'd from the scrape —
    # confirm the intended characters against the deployed app.
    st.title("๐ฎ Guess It! - 20 Questions Game")
    st.markdown("<div class='subtitle'>Think of a person, place, or object. LLaMA will try to guess it!</div>", unsafe_allow_html=True)
    category = st.selectbox("Category of your secret:", ["Person", "Place", "Object"])
    # First run: initialize the chat transcript in session state.
    if "conversation" not in st.session_state:
        st.session_state.conversation = []
        st.session_state.last_bot_msg = ""
    if st.button("๐ Restart Game"):
        st.session_state.conversation = []
        st.session_state.last_bot_msg = ""
        st.rerun()
    # Empty transcript -> ask LLaMA for its opening question.
    if not st.session_state.conversation:
        st.session_state.last_bot_msg = ask_llama([], category)
        st.session_state.conversation.append({"role": "assistant", "content": st.session_state.last_bot_msg})
    st.markdown(f"<div class='question-box'><strong>LLaMA:</strong> {st.session_state.last_bot_msg}</div>", unsafe_allow_html=True)
    # Voice transcription takes precedence; the text box is the fallback
    # (short-circuit `or` means the text value is used only when voice is empty).
    user_input = get_voice_transcription("voice_input") or st.text_input("๐ฌ Your answer (yes/no/sometimes):")
    if st.button("Submit Answer") and user_input:
        st.session_state.conversation.append({"role": "user", "content": user_input})
        with st.spinner("Thinking..."):
            response = ask_llama(st.session_state.conversation, category)
            st.session_state.last_bot_msg = response
            st.session_state.conversation.append({"role": "assistant", "content": response})
            st.rerun()
    if st.button("๐ค Make Final Guess"):
        with st.spinner("Making final guess..."):
            final_guess = ask_llama(st.session_state.conversation, category, is_final_guess=True)
            st.markdown(f"<div class='final-reveal'>๐คฏ Final Guess: <strong>{final_guess}</strong></div>", unsafe_allow_html=True)
            show_confetti()


if __name__ == "__main__":
    main()