# Websockets / app.py
# Author: awacke1
# Last commit: "Update app.py" (f28d146, verified)
import asyncio
import io
import logging
import threading

import numpy as np
import soundfile as sf
import streamlit as st
import streamlit.components.v1 as components
import torch
import websockets
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
# Load the pre-trained wav2vec2 CTC model and its tokenizer from a single
# shared checkpoint id. Both are module-level globals used by the speech
# handler below; they are loaded whenever this module is executed.
_MODEL_NAME = "facebook/wav2vec2-base-960h"
tokenizer, model = (
    Wav2Vec2Tokenizer.from_pretrained(_MODEL_NAME),
    Wav2Vec2ForCTC.from_pretrained(_MODEL_NAME),
)
# facebook/wav2vec2-base-960h expects 16 kHz mono input (see its model card).
_TARGET_SAMPLE_RATE = 16_000
_logger = logging.getLogger(__name__)


def _to_mono_16k(waveform, samplerate):
    """Downmix *waveform* to mono float32 and resample it to 16 kHz.

    ``sf.read`` returns shape (frames,) for mono audio and
    (frames, channels) otherwise; channels are averaged. Resampling is linear
    interpolation without an anti-aliasing filter -- a rough approximation;
    swap in a proper resampler if accuracy matters.
    """
    if waveform.ndim > 1:
        # A 2-D array would otherwise be treated as a batch by the tokenizer.
        waveform = waveform.mean(axis=1)
    if samplerate != _TARGET_SAMPLE_RATE and waveform.size:
        n_out = max(1, round(waveform.size * _TARGET_SAMPLE_RATE / samplerate))
        positions = np.linspace(0, waveform.size - 1, n_out)
        waveform = np.interp(positions, np.arange(waveform.size), waveform)
    # The model weights are float32; float64 input makes the forward pass fail.
    return waveform.astype(np.float32, copy=False)


def _transcribe(waveform):
    """Greedy CTC-decode a 16 kHz mono float32 waveform and return the text."""
    input_values = tokenizer(waveform, return_tensors="pt").input_values
    # no_grad is thread-local, so it must be entered in the worker thread.
    with torch.no_grad():
        logits = model(input_values.to(model.dtype)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return tokenizer.decode(predicted_ids[0])


async def recognize_speech(websocket):
    """Transcribe each binary audio message on *websocket* and reply with text.

    Every message must be a complete audio file that soundfile can decode
    (e.g. WAV). Messages that are not bytes, fail to decode, are empty, or
    fail during inference are logged and skipped, so one bad message no
    longer closes the connection.

    Args:
        websocket: A ``websockets`` server connection; iterated for incoming
            messages and used to ``send`` one transcription per message.
    """
    loop = asyncio.get_running_loop()
    async for message in websocket:
        if not isinstance(message, (bytes, bytearray)):
            _logger.warning("Ignoring non-binary message (%d chars)", len(message))
            continue
        try:
            waveform, samplerate = sf.read(io.BytesIO(message), dtype="float32")
        except RuntimeError:  # soundfile's decode errors subclass RuntimeError
            _logger.exception("Could not decode %d-byte audio message", len(message))
            continue
        waveform = _to_mono_16k(waveform, samplerate)
        if waveform.size == 0:
            continue
        try:
            # Inference is CPU-bound; run it off the event loop so other
            # connections and keepalive pings are not stalled.
            transcription = await loop.run_in_executor(None, _transcribe, waveform)
        except RuntimeError:
            _logger.exception("Inference failed on %d samples", waveform.size)
            continue
        await websocket.send(transcription)
async def main_logic(host="localhost", port=8000):
    """Serve ``recognize_speech`` over WebSocket on ws://host:port forever.

    Args:
        host: Interface to bind. The default "localhost" accepts only local
            connections; pass "0.0.0.0" to listen on all interfaces.
        port: TCP port to bind (default 8000, which the browser client uses).
    """
    async with websockets.serve(recognize_speech, host, port):
        await asyncio.Future()  # never completes, so the server runs forever
# Create the Streamlit interface.
st.title("Real-Time ASR with Transformers.js")

# Browser-side client. After the user clicks "Start", it captures microphone
# audio with the Web Audio API and downsamples it to 16 kHz mono. It packs
# roughly one second at a time into a self-contained 16-bit PCM WAV file and
# sends each file to the WebSocket server at ws://localhost:8000. Each chunk is
# decodable on its own; MediaRecorder webm timeslices were not. Text replies
# are appended to the page as plain text.
_ASR_CLIENT_HTML = """
<button id="start">Start microphone</button>
<div id="transcription">Your transcriptions will appear here:</div>
<script>
const TARGET_RATE = 16000;   // sampling rate sent to the server
const CHUNK_SECONDS = 1;     // audio duration per WebSocket message

const transcriptionDiv = document.getElementById("transcription");
const startButton = document.getElementById("start");

// Average consecutive input samples down to outRate (crude low-pass + decimate).
const downsample = (input, inRate, outRate) => {
  if (outRate >= inRate) return input;
  const ratio = inRate / outRate;
  const out = new Float32Array(Math.floor(input.length / ratio));
  for (let i = 0; i < out.length; i++) {
    const start = Math.floor(i * ratio);
    const end = Math.min(input.length, Math.floor((i + 1) * ratio));
    let sum = 0;
    for (let j = start; j < end; j++) sum += input[j];
    out[i] = end > start ? sum / (end - start) : 0;
  }
  return out;
};

// Pack mono float samples in [-1, 1] into a 16-bit PCM WAV file.
const encodeWav = (samples, sampleRate) => {
  const buffer = new ArrayBuffer(44 + samples.length * 2);
  const view = new DataView(buffer);
  const writeAscii = (offset, text) => {
    for (let i = 0; i < text.length; i++) view.setUint8(offset + i, text.charCodeAt(i));
  };
  writeAscii(0, "RIFF");
  view.setUint32(4, 36 + samples.length * 2, true);
  writeAscii(8, "WAVE");
  writeAscii(12, "fmt ");
  view.setUint32(16, 16, true);             // fmt chunk size
  view.setUint16(20, 1, true);              // audio format: PCM
  view.setUint16(22, 1, true);              // channels: mono
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, sampleRate * 2, true); // byte rate
  view.setUint16(32, 2, true);              // block align
  view.setUint16(34, 16, true);             // bits per sample
  writeAscii(36, "data");
  view.setUint32(40, samples.length * 2, true);
  for (let i = 0; i < samples.length; i++) {
    const s = Math.max(-1, Math.min(1, samples[i]));
    view.setInt16(44 + i * 2, s < 0 ? s * 0x8000 : s * 0x7fff, true);
  }
  return buffer;
};

const appendTranscription = (text) => {
  const line = document.createElement("div");
  line.textContent = text;  // never parse server text as HTML
  transcriptionDiv.appendChild(line);
};

const handleAudio = (stream, audioContext) => {
  const websocket = new WebSocket("ws://localhost:8000");
  websocket.onopen = () => console.log("Connected to WebSocket");
  websocket.onerror = (error) => console.error("WebSocket Error:", error);
  websocket.onclose = () => console.log("WebSocket Closed");
  websocket.onmessage = (event) => {
    console.log("Received:", event.data);
    appendTranscription(event.data);
  };

  const source = audioContext.createMediaStreamSource(stream);
  // ScriptProcessorNode is deprecated but needs no separate worklet module.
  const processor = audioContext.createScriptProcessor(4096, 1, 1);
  const chunkLength = audioContext.sampleRate * CHUNK_SECONDS;
  let pending = [];
  let pendingLength = 0;

  processor.onaudioprocess = (event) => {
    // Copy: the browser may reuse the input buffer between callbacks.
    pending.push(new Float32Array(event.inputBuffer.getChannelData(0)));
    pendingLength += event.inputBuffer.length;
    if (pendingLength < chunkLength) return;

    const merged = new Float32Array(pendingLength);
    let offset = 0;
    for (const part of pending) {
      merged.set(part, offset);
      offset += part.length;
    }
    pending = [];
    pendingLength = 0;

    // send() throws while the socket is still CONNECTING; drop audio until OPEN.
    if (websocket.readyState !== WebSocket.OPEN) return;
    const rate = Math.min(audioContext.sampleRate, TARGET_RATE);
    websocket.send(encodeWav(downsample(merged, audioContext.sampleRate, rate), rate));
  };
  source.connect(processor);
  // Some browsers (notably Chrome) only fire onaudioprocess on nodes that are
  // connected to an output; the processor leaves its output buffer untouched.
  processor.connect(audioContext.destination);
};

startButton.addEventListener("click", () => {
  startButton.disabled = true;
  // Create the AudioContext inside the click so autoplay policy lets it run.
  const audioContext = new AudioContext();
  navigator.mediaDevices.getUserMedia({ audio: true })
    .then((stream) => handleAudio(stream, audioContext))
    .catch((error) => {
      console.error("getUserMedia Error:", error);
      startButton.disabled = false;
    });
});
</script>
"""

# st.markdown(..., unsafe_allow_html=True) inserts HTML without executing
# <script> tags, so the client is rendered with components.html instead. That
# places it in an iframe where the script actually runs.
components.html(_ASR_CLIENT_HTML, height=300, scrolling=True)
# Thread name used as a process-wide marker that the server is already up.
_SERVER_THREAD_NAME = "asr-websocket-server"


def _start_server_in_background():
    """Start ``main_logic`` on a daemon thread unless one is already running.

    Streamlit re-executes this script for every session and rerun in the
    same process, so module globals cannot remember that the server was
    started. Live thread names persist for the process lifetime and serve as
    the marker. Binding port 8000 a second time would raise OSError.
    """
    if any(t.name == _SERVER_THREAD_NAME for t in threading.enumerate()):
        return
    threading.Thread(
        target=lambda: asyncio.run(main_logic()),
        name=_SERVER_THREAD_NAME,
        daemon=True,
    ).start()


if __name__ == "__main__":
    if threading.current_thread() is threading.main_thread():
        # `python app.py`: serve in the foreground until interrupted.
        asyncio.run(main_logic())
    else:
        # `streamlit run app.py` executes the script on a worker thread. A
        # blocking asyncio.run here would hang that script run, so serve in
        # the background instead.
        _start_server_in_background()