awacke1 committed
Commit f28d146 · verified · 1 Parent(s): e33335a

Update app.py

Files changed (1)
  1. app.py +23 -34
app.py CHANGED
@@ -13,76 +13,65 @@ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
 async def recognize_speech(websocket):
     async for message in websocket:
-        try:
-            # Read audio data from message
-            wf, samplerate = sf.read(io.BytesIO(message))
-            # Tokenize input values
-            input_values = tokenizer(wf, return_tensors="pt").input_values
-            # Predict logits
-            with torch.no_grad():
-                logits = model(input_values).logits
-            # Decode predictions
-            predicted_ids = torch.argmax(logits, dim=-1)
-            transcription = tokenizer.decode(predicted_ids[0])
-            # Send transcription back to the client
-            await websocket.send(transcription)
-        except Exception as e:
-            print(f"Error in recognize_speech: {e}")
-            await websocket.send("Error processing audio data.")
+        wf, samplerate = sf.read(io.BytesIO(message))
+        input_values = tokenizer(wf, return_tensors="pt").input_values
+        with torch.no_grad():
+            logits = model(input_values).logits
+
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = tokenizer.decode(predicted_ids[0])
+        await websocket.send(transcription)
 
 async def main_logic():
     async with websockets.serve(recognize_speech, "localhost", 8000):
         await asyncio.Future()  # run forever
 
-# Streamlit interface
-st.title("Real-Time ASR with Transformers")
+# Create the streamlit interface
+st.title("Real-Time ASR with Transformers.js")
 
-# WebSocket script for the frontend
+# The script can't be run via "streamlit run" because that hangs the asyncio loop
 st.markdown("""
 <script>
 const handleAudio = async (stream) => {
     const websocket = new WebSocket('ws://localhost:8000');
-    const mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
+    const mediaRecorder = new MediaRecorder(stream, {mimeType: 'audio/webm'});
     const audioChunks = [];
 
     mediaRecorder.addEventListener("dataavailable", event => {
+        console.log('dataavailable:', event.data);
         audioChunks.push(event.data);
-    });
-
-    mediaRecorder.addEventListener("stop", () => {
-        const audioBlob = new Blob(audioChunks);
-        websocket.send(audioBlob);
+        websocket.send(event.data);
     });
 
     websocket.onmessage = (event) => {
         const transcription = event.data;
         const transcriptionDiv = document.getElementById("transcription");
-        transcriptionDiv.innerHTML += `<div>${transcription}</div>`;
+        transcriptionDiv.innerHTML = transcriptionDiv.innerHTML + transcription + "<br/>";
+        console.log('Received:', transcription);
     };
 
+    mediaRecorder.start(1000);
+
     websocket.onopen = () => {
-        console.log('WebSocket connection established.');
+        console.log('Connected to WebSocket');
    };
 
     websocket.onerror = (error) => {
-        console.error('WebSocket error:', error);
+        console.error('WebSocket Error:', error);
     };
 
     websocket.onclose = () => {
-        console.log('WebSocket connection closed.');
+        console.log('WebSocket Closed');
     };
-
-    mediaRecorder.start(1000);
 };
 
 navigator.mediaDevices.getUserMedia({ audio: true })
     .then(handleAudio)
-    .catch(error => console.error('Error accessing media devices.', error));
+    .catch(error => console.error('getUserMedia Error:', error));
 </script>
 
 <div id="transcription">Your transcriptions will appear here:</div>
 """, unsafe_allow_html=True)
 
-# To run the WebSocket server
 if __name__ == "__main__":
-    asyncio.get_event_loop().run_until_complete(main_logic())
+    asyncio.run(main_logic())
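
With this change the browser sends each MediaRecorder chunk as soon as it is produced (mediaRecorder.start(1000) fires a dataavailable event roughly every 1000 ms) instead of one blob on stop, and the server is launched with asyncio.run(main_logic()). Below is a minimal sketch of a standalone test client for the updated server. It is not part of the commit: the file name sample.wav is hypothetical, the audio is assumed to be a 16 kHz mono WAV (the rate facebook/wav2vec2-base-960h expects), and it uses the same websockets package the server imports.

# test_client.py -- hypothetical helper, not part of this commit.
# Assumes app.py is serving on ws://localhost:8000 and that sample.wav
# is a 16 kHz mono WAV file that soundfile can decode on the server side.
import asyncio
import websockets

async def send_audio():
    async with websockets.connect("ws://localhost:8000") as ws:
        with open("sample.wav", "rb") as f:
            await ws.send(f.read())        # one complete audio file per message
        transcription = await ws.recv()    # recognize_speech() answers with text
        print("Transcription:", transcription)

asyncio.run(send_audio())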
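
The inference path inside recognize_speech can also be exercised offline. This sketch mirrors the function body under one stated assumption: the tokenizer defined above the hunk is taken to be Wav2Vec2Tokenizer, since only the model = Wav2Vec2ForCTC line is visible in the hunk header; sample.wav is again a hypothetical 16 kHz mono file.

# offline_check.py -- hypothetical, mirrors the body of recognize_speech().
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer  # tokenizer class assumed

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

wf, samplerate = sf.read("sample.wav")                # 16 kHz mono assumed
input_values = tokenizer(wf, return_tensors="pt").input_values
with torch.no_grad():
    logits = model(input_values).logits               # per-frame CTC logits
predicted_ids = torch.argmax(logits, dim=-1)          # greedy CTC decode
print(tokenizer.decode(predicted_ids[0]))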