awacke1 committed
Commit e33335a · verified · 1 Parent(s): 23cb5b2

Update app.py

Files changed (1)
  1. app.py +34 -23
app.py CHANGED
@@ -13,65 +13,76 @@ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
 async def recognize_speech(websocket):
     async for message in websocket:
-        wf, samplerate = sf.read(io.BytesIO(message))
-        input_values = tokenizer(wf, return_tensors="pt").input_values
-        with torch.no_grad():
-            logits = model(input_values).logits
-
-        predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = tokenizer.decode(predicted_ids[0])
-        await websocket.send(transcription)
+        try:
+            # Read audio data from message
+            wf, samplerate = sf.read(io.BytesIO(message))
+            # Tokenize input values
+            input_values = tokenizer(wf, return_tensors="pt").input_values
+            # Predict logits
+            with torch.no_grad():
+                logits = model(input_values).logits
+            # Decode predictions
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription = tokenizer.decode(predicted_ids[0])
+            # Send transcription back to the client
+            await websocket.send(transcription)
+        except Exception as e:
+            print(f"Error in recognize_speech: {e}")
+            await websocket.send("Error processing audio data.")
 
 async def main_logic():
     async with websockets.serve(recognize_speech, "localhost", 8000):
         await asyncio.Future()  # run forever
 
-# Create the streamlit interface
-st.title("Real-Time ASR with Transformers.js")
+# Streamlit interface
+st.title("Real-Time ASR with Transformers")
 
-# The script can't be run via "streamlit run" because that hangs asyncio loop
+# WebSocket script for the frontend
 st.markdown("""
 <script>
 const handleAudio = async (stream) => {
     const websocket = new WebSocket('ws://localhost:8000');
-    const mediaRecorder = new MediaRecorder(stream, {mimeType: 'audio/webm'});
+    const mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
     const audioChunks = [];
 
     mediaRecorder.addEventListener("dataavailable", event => {
-        console.log('dataavailable:', event.data);
         audioChunks.push(event.data);
-        websocket.send(event.data);
+    });
+
+    mediaRecorder.addEventListener("stop", () => {
+        const audioBlob = new Blob(audioChunks);
+        websocket.send(audioBlob);
     });
 
     websocket.onmessage = (event) => {
         const transcription = event.data;
         const transcriptionDiv = document.getElementById("transcription");
-        transcriptionDiv.innerHTML = transcriptionDiv.innerHTML + transcription + "<br/>";
-        console.log('Received:', transcription);
+        transcriptionDiv.innerHTML += `<div>${transcription}</div>`;
     };
 
-    mediaRecorder.start(1000);
-
     websocket.onopen = () => {
-        console.log('Connected to WebSocket');
+        console.log('WebSocket connection established.');
     };
 
     websocket.onerror = (error) => {
-        console.error('WebSocket Error:', error);
+        console.error('WebSocket error:', error);
    };
 
     websocket.onclose = () => {
-        console.log('WebSocket Closed');
+        console.log('WebSocket connection closed.');
     };
+
+    mediaRecorder.start(1000);
 };
 
 navigator.mediaDevices.getUserMedia({ audio: true })
     .then(handleAudio)
-    .catch(error => console.error('getUserMedia Error:', error));
+    .catch(error => console.error('Error accessing media devices.', error));
 </script>
 
 <div id="transcription">Your transcriptions will appear here:</div>
 """, unsafe_allow_html=True)
 
+# To run the WebSocket server
 if __name__ == "__main__":
-    asyncio.run(main_logic())
+    asyncio.get_event_loop().run_until_complete(main_logic())
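
As a quick smoke test of the server side of this change, a minimal client along the following lines can send one audio clip and print the reply. This is a sketch, not part of the commit: the sample.wav filename and the one-binary-message-per-clip protocol are assumptions, and per the removed comment in the old version the script is started directly (python app.py) rather than via streamlit run, so the asyncio server loop actually runs.

# Hypothetical smoke-test client (not part of the commit). Assumes the
# `websockets` and `soundfile` packages and a local file named sample.wav.
import asyncio
import io

import soundfile as sf
import websockets

async def send_clip(path="sample.wav"):
    # Re-encode the clip as WAV bytes so the server's sf.read() can decode it.
    data, samplerate = sf.read(path)
    buf = io.BytesIO()
    sf.write(buf, data, samplerate, format="WAV")
    async with websockets.connect("ws://localhost:8000") as ws:
        await ws.send(buf.getvalue())   # one binary message per clip
        print(await ws.recv())          # server replies with the transcription

if __name__ == "__main__":
    asyncio.run(send_clip())

If the reply is "Error processing audio data.", the new except branch in recognize_speech fired, which usually means sf.read could not decode the bytes it received.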