# Websockets / app.py
# Author: awacke1 — commit 7be7ac1 ("Create app.py", verified), 2.5 kB
import asyncio
import io
import logging
import threading

import numpy as np
import soundfile as sf
import streamlit as st
import streamlit.components.v1 as components
import torch
import websockets
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
# Load the pre-trained model and tokenizer once per process. Streamlit
# re-executes this script for every session/rerun; st.cache_resource shares a
# single loaded copy instead of downloading/deserializing the weights each run.
# NOTE(review): transformers deprecates Wav2Vec2Tokenizer in favour of
# Wav2Vec2Processor / Wav2Vec2CTCTokenizer; it is kept here so the existing
# `tokenizer(...)` / `tokenizer.decode(...)` call sites keep working unchanged.
@st.cache_resource
def _load_asr_model(model_name="facebook/wav2vec2-base-960h"):
    """Return ``(tokenizer, model)`` for the wav2vec2 CTC checkpoint *model_name*."""
    return (
        Wav2Vec2Tokenizer.from_pretrained(model_name),
        Wav2Vec2ForCTC.from_pretrained(model_name),
    )


tokenizer, model = _load_asr_model()
# wav2vec2-base-960h was fine-tuned on 16 kHz speech; input must match that rate.
_TARGET_SAMPLE_RATE = 16000


def _to_mono_16k(samples, samplerate, target_rate=_TARGET_SAMPLE_RATE):
    """Return *samples* as a 1-D float32 array sampled at *target_rate* Hz.

    Args:
        samples: Array as returned by ``soundfile.read`` — 1-D for mono,
            ``(frames, channels)`` for multi-channel audio.
        samplerate: Sample rate of *samples* in Hz.
        target_rate: Desired output sample rate in Hz.

    Returns:
        A 1-D float32 array. Multi-channel input is down-mixed by averaging
        the channels. Resampling uses linear interpolation without an
        anti-aliasing filter, so some aliasing is possible when downsampling.
    """
    audio = np.asarray(samples, dtype=np.float32)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if samplerate == target_rate or audio.size == 0:
        return audio
    n_out = max(1, int(round(audio.shape[0] * target_rate / samplerate)))
    src_times = np.arange(audio.shape[0]) / samplerate
    dst_times = np.arange(n_out) / target_rate
    return np.interp(dst_times, src_times, audio).astype(np.float32)


def _transcribe(audio_bytes):
    """Decode one self-contained audio file and return its greedy CTC transcription.

    Returns:
        The decoded text, or None when the file contains no samples.

    Raises:
        RuntimeError: soundfile cannot decode the bytes, or the clip is too
            short for the model's convolutional feature encoder.
    """
    samples, samplerate = sf.read(io.BytesIO(audio_bytes))
    audio = _to_mono_16k(samples, samplerate)
    if audio.size == 0:
        return None
    input_values = tokenizer(audio, return_tensors="pt").input_values
    # Grad mode is thread-local in PyTorch, so this applies to the executor
    # thread this function runs on.
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return tokenizer.decode(predicted_ids[0])


async def recognize_speech(websocket):
    """Transcribe every binary audio message on *websocket* and send the text back.

    Each binary message must be a complete audio file that soundfile can
    decode (e.g. WAV); it is transcribed independently of earlier messages.
    Text frames are ignored. Messages that fail to decode or transcribe are
    logged and skipped so one bad chunk does not close the connection.
    """
    loop = asyncio.get_running_loop()
    async for message in websocket:
        if isinstance(message, str):
            # Text frames carry no audio (and BytesIO would reject a str).
            continue
        try:
            # Decoding and inference are blocking CPU work; run them in the
            # default executor so the event loop keeps servicing keepalive
            # pings and other connections.
            transcription = await loop.run_in_executor(None, _transcribe, message)
        except RuntimeError as err:
            logging.getLogger(__name__).warning(
                "Skipping undecodable audio message: %s", err
            )
            continue
        if transcription is not None:
            await websocket.send(transcription)
# Thread name doubles as a process-wide "server already started" marker.
_SERVER_THREAD_NAME = "asr-websocket-server"


async def main_logic(host="localhost", port=8000):
    """Serve ``recognize_speech`` on ``ws://host:port`` until the process exits.

    Args:
        host: Interface to bind; defaults to "localhost", the host the
            browser client connects to.
        port: TCP port to bind; defaults to 8000.
    """
    async with websockets.serve(recognize_speech, host, port):
        await asyncio.Future()  # run forever


def _serve_forever():
    """Thread target: run ``main_logic`` on this thread's own event loop."""
    try:
        asyncio.run(main_logic())
    except OSError as err:
        # Typically the port is already bound, e.g. by a server thread that a
        # concurrent session started at the same moment.
        logging.getLogger(__name__).warning(
            "WebSocket server not started: %s", err
        )


def _ensure_server_running():
    """Start the WebSocket server on a daemon thread unless one is already alive.

    Calling ``main_logic`` directly would block the Streamlit script forever,
    so it runs in the background instead. Streamlit re-executes this script on
    every run, so module globals are not a reliable "already started" flag;
    the live thread list persists for the whole process and is used instead.
    """
    if any(t.name == _SERVER_THREAD_NAME for t in threading.enumerate()):
        return
    threading.Thread(
        target=_serve_forever, name=_SERVER_THREAD_NAME, daemon=True
    ).start()


_ensure_server_running()
# Create the streamlit interface
st.title("Real-Time ASR with Transformers")

# Browser-side client, rendered with components.html because Streamlit does
# not execute <script> tags passed to st.markdown(..., unsafe_allow_html=True).
# The client:
#   * starts on a click (most browsers only let an AudioContext run after a
#     user gesture, and it avoids grabbing the microphone on page load);
#   * captures raw microphone samples and, once per second, sends them as a
#     complete 16-bit PCM WAV file. The server decodes every message on its
#     own with soundfile, which cannot read MediaRecorder's WebM fragments;
#   * appends each transcription it receives as plain text, not HTML.
# NOTE(review): "localhost" is resolved by the browser, so this only works
# when the browser runs on the server's machine; a deployed app needs a
# public WebSocket URL (wss:// when the page is served over https).
_ASR_CLIENT_HTML = """
<button id="start">Start recording</button>
<div id="transcription">Your transcriptions will appear here:<br/></div>
<script>
const WS_URL = 'ws://localhost:8000';
const CHUNK_MS = 1000;  // how often buffered audio is sent to the server
const startButton = document.getElementById('start');
const transcriptionDiv = document.getElementById('transcription');

// Encode mono float samples in [-1, 1] as a complete 16-bit PCM WAV file.
const encodeWav = (samples, sampleRate) => {
  const buffer = new ArrayBuffer(44 + samples.length * 2);
  const view = new DataView(buffer);
  const writeAscii = (offset, text) => {
    for (let i = 0; i < text.length; i++) view.setUint8(offset + i, text.charCodeAt(i));
  };
  writeAscii(0, 'RIFF');
  view.setUint32(4, 36 + samples.length * 2, true);
  writeAscii(8, 'WAVE');
  writeAscii(12, 'fmt ');
  view.setUint32(16, 16, true);              // fmt chunk size
  view.setUint16(20, 1, true);               // format: PCM
  view.setUint16(22, 1, true);               // channels: mono
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, sampleRate * 2, true);  // byte rate
  view.setUint16(32, 2, true);               // block align
  view.setUint16(34, 16, true);              // bits per sample
  writeAscii(36, 'data');
  view.setUint32(40, samples.length * 2, true);
  for (let i = 0; i < samples.length; i++) {
    const s = Math.max(-1, Math.min(1, samples[i]));
    view.setInt16(44 + i * 2, s < 0 ? s * 0x8000 : s * 0x7fff, true);
  }
  return buffer;
};

const start = async () => {
  startButton.disabled = true;
  // Created inside the click handler so autoplay policies let it run.
  const audioCtx = new AudioContext();
  let stream;
  try {
    stream = await navigator.mediaDevices.getUserMedia({ audio: true });
  } catch (error) {
    console.error('getUserMedia Error:', error);
    audioCtx.close();
    startButton.disabled = false;
    return;
  }
  const source = audioCtx.createMediaStreamSource(stream);
  // ScriptProcessorNode is deprecated but needs no separate worklet module.
  const processor = audioCtx.createScriptProcessor(4096, 1, 1);
  // Route through a muted gain node: the processor must be connected to the
  // graph to receive audio, but the microphone must not be played back.
  const mute = audioCtx.createGain();
  mute.gain.value = 0;
  let pending = [];
  processor.onaudioprocess = (event) => {
    // Copy: the browser reuses the input buffer between callbacks.
    pending.push(new Float32Array(event.inputBuffer.getChannelData(0)));
  };
  source.connect(processor);
  processor.connect(mute);
  mute.connect(audioCtx.destination);

  const websocket = new WebSocket(WS_URL);
  let timer = null;
  const flush = () => {
    if (pending.length === 0 || websocket.readyState !== WebSocket.OPEN) return;
    const total = pending.reduce((n, chunk) => n + chunk.length, 0);
    const samples = new Float32Array(total);
    let offset = 0;
    for (const chunk of pending) {
      samples.set(chunk, offset);
      offset += chunk.length;
    }
    pending = [];
    websocket.send(encodeWav(samples, audioCtx.sampleRate));
  };
  websocket.onopen = () => {
    console.log('Connected to WebSocket');
    // Only start sending once the socket is open.
    timer = setInterval(flush, CHUNK_MS);
  };
  websocket.onmessage = (event) => {
    const transcription = event.data;
    console.log('Received:', transcription);
    // Insert as a text node so server output is never parsed as HTML.
    transcriptionDiv.appendChild(document.createTextNode(transcription));
    transcriptionDiv.appendChild(document.createElement('br'));
  };
  websocket.onerror = (error) => {
    console.error('WebSocket Error:', error);
  };
  websocket.onclose = () => {
    console.log('WebSocket Closed');
    clearInterval(timer);
    processor.disconnect();
    mute.disconnect();
    source.disconnect();
    stream.getTracks().forEach((track) => track.stop());
    audioCtx.close();
    startButton.disabled = false;
  };
};

startButton.addEventListener('click', start);
</script>
"""
components.html(_ASR_CLIENT_HTML, height=400, scrolling=True)