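"""
Realtime speech-to-text demo: the browser streams 16 kHz PCM audio over a
WebSocket, Silero VAD segments it into utterances, and ONNX Whisper models
(tiny.en or base.en, switchable at runtime) produce the transcripts.
"""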
import time
import asyncio
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from silero_vad import VADIterator, load_silero_vad
from transformers import AutoProcessor, pipeline, WhisperTokenizerFast
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
# Load the English-only Whisper models exported to ONNX.
# whisper-tiny.en: fastest, lowest accuracy.
processor_tiny = AutoProcessor.from_pretrained("onnx-community/whisper-tiny.en")
model_tiny = ORTModelForSpeechSeq2Seq.from_pretrained("onnx-community/whisper-tiny.en", subfolder="onnx")
tokenizer_tiny = WhisperTokenizerFast.from_pretrained("onnx-community/whisper-tiny.en", language="english")
pipe_tiny = pipeline("automatic-speech-recognition", model=model_tiny, tokenizer=tokenizer_tiny, feature_extractor=processor_tiny.feature_extractor)

# whisper-base.en: slower, more accurate.
processor_base = AutoProcessor.from_pretrained("onnx-community/whisper-base.en")
model_base = ORTModelForSpeechSeq2Seq.from_pretrained("onnx-community/whisper-base.en", subfolder="onnx")
tokenizer_base = WhisperTokenizerFast.from_pretrained("onnx-community/whisper-base.en", language="english")
pipe_base = pipeline("automatic-speech-recognition", model=model_base, tokenizer=tokenizer_base, feature_extractor=processor_base.feature_extractor)
# Constants
SAMPLING_RATE = 16000    # sample rate of the incoming PCM stream (Hz)
CHUNK_SIZE = 512         # samples per chunk; Silero VAD expects 512-sample chunks at 16 kHz
LOOKBACK_CHUNKS = 5      # chunks of audio kept before speech is detected
MAX_SPEECH_SECS = 15     # force a final transcript if an utterance grows past this
MIN_REFRESH_SECS = 1     # minimum interval between partial-transcript refreshes

app = FastAPI()

vad_model = load_silero_vad(onnx=True)
vad_iterator = VADIterator(
    model=vad_model,
    sampling_rate=SAMPLING_RATE,
    threshold=0.5,
    min_silence_duration_ms=300,
)

def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """
    Convert 16-bit PCM bytes into a float32 numpy array with values in [-1, 1].
    """
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    float_data = int_data.astype(np.float32) / 32768.0
    return float_data
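
# WebSocket protocol:
#   * binary frames: 512-sample chunks of 16-bit little-endian PCM at 16 kHz
#   * text frames: "switch_to_tiny" / "switch_to_base" to change the ASR model
# The server replies with JSON messages of type "partial", "final", or "status".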
@app.websocket("/ws/transcribe")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    caption_cache = []
    speech = np.empty(0, dtype=np.float32)
    recording = False
    last_partial_time = time.time()
    current_pipe = pipe_tiny
    try:
        while True:
            data = await websocket.receive()
            if data["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(code=data.get("code", 1000))
            if data["type"] != "websocket.receive":
                continue
            # Text frames carry model-switch commands from the client.
            if data.get("text") == "switch_to_tiny":
                current_pipe = pipe_tiny
                continue
            elif data.get("text") == "switch_to_base":
                current_pipe = pipe_base
                continue
            if data.get("bytes") is None:
                continue
            chunk = pcm16_to_float32(data["bytes"])
            speech = np.concatenate((speech, chunk))
            if not recording:
                # Keep only a short lookback buffer until speech is detected.
                speech = speech[-(LOOKBACK_CHUNKS * CHUNK_SIZE):]
            vad_result = vad_iterator(chunk)
            if vad_result:
                if "start" in vad_result and not recording:
                    recording = True
                    await websocket.send_json({"type": "status", "message": "speaking_started"})
                if "end" in vad_result and recording:
                    recording = False
                    text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                    await websocket.send_json({"type": "final", "transcript": text})
                    caption_cache.append(text)
                    speech = np.empty(0, dtype=np.float32)
                    # Reset the VAD state so the next utterance starts cleanly.
                    vad_iterator.triggered = False
                    vad_iterator.temp_end = 0
                    vad_iterator.current_sample = 0
                    await websocket.send_json({"type": "status", "message": "speaking_stopped"})
            elif recording:
                if speech.size / SAMPLING_RATE > MAX_SPEECH_SECS:
                    # Force a final transcript if a single utterance runs too long.
                    recording = False
                    text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                    await websocket.send_json({"type": "final", "transcript": text})
                    caption_cache.append(text)
                    speech = np.empty(0, dtype=np.float32)
                    vad_iterator.triggered = False
                    vad_iterator.temp_end = 0
                    vad_iterator.current_sample = 0
                    await websocket.send_json({"type": "status", "message": "speaking_stopped"})
                elif time.time() - last_partial_time > MIN_REFRESH_SECS:
                    # Refresh the in-progress (gray) caption at most once per MIN_REFRESH_SECS.
                    last_partial_time = time.time()
                    text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
                    await websocket.send_json({"type": "partial", "transcript": text})
    except WebSocketDisconnect:
        # The client is gone, so any trailing audio can only be transcribed and cached, not sent.
        if recording and speech.size:
            text = current_pipe({"sampling_rate": SAMPLING_RATE, "raw": speech})["text"]
            caption_cache.append(text)
        print("WebSocket disconnected")
@app.get("/", response_class=HTMLResponse)
async def get_home():
return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
    <title>Realtime Transcription</title>
    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/tailwind.min.css" rel="stylesheet">
</head>
<body class="bg-gray-100 p-6">
    <div class="max-w-3xl mx-auto bg-white p-6 rounded-lg shadow-md">
        <h1 class="text-2xl font-bold mb-4">Realtime Transcription</h1>
        <button onclick="startTranscription()" class="bg-blue-500 text-white px-4 py-2 rounded mb-4">Start Transcription</button>
        <select id="modelSelect" onchange="switchModel()" class="bg-gray-200 px-4 py-2 rounded mb-4">
            <option value="tiny">Tiny Model</option>
            <option value="base">Base Model</option>
        </select>
        <p id="status" class="text-gray-600 mb-4">Click start to begin transcription.</p>
        <p id="speakingStatus" class="text-gray-600 mb-4"></p>
        <div id="transcription" class="border p-4 rounded mb-4 h-64 overflow-auto"></div>
        <div id="visualizer" class="border p-4 rounded h-64">
            <canvas id="audioCanvas" class="w-full h-full"></canvas>
        </div>
    </div>
    <script>
        let ws;
        let audioContext;
        let scriptProcessor;
        let mediaStream;
        // Current caption line: partial results overwrite it in gray, finals freeze it in black.
        let currentLine = document.createElement('span');
        let analyser;
        let canvas, canvasContext;
        document.getElementById('transcription').appendChild(currentLine);
        canvas = document.getElementById('audioCanvas');
        canvasContext = canvas.getContext('2d');
        async function startTranscription() {
            document.getElementById("status").innerText = "Connecting...";
            // Use ws:// when the page is served over plain HTTP (e.g. local uvicorn), wss:// otherwise.
            const wsProtocol = location.protocol === "https:" ? "wss://" : "ws://";
            ws = new WebSocket(wsProtocol + location.host + "/ws/transcribe");
            ws.binaryType = 'arraybuffer';
            ws.onopen = async function() {
                document.getElementById("status").innerText = "Connected";
                try {
                    mediaStream = await navigator.mediaDevices.getUserMedia({ audio: true });
                    audioContext = new AudioContext({ sampleRate: 16000 });
                    const source = audioContext.createMediaStreamSource(mediaStream);
                    analyser = audioContext.createAnalyser();
                    analyser.fftSize = 2048;
                    const bufferLength = analyser.frequencyBinCount;
                    const dataArray = new Uint8Array(bufferLength);
                    source.connect(analyser);
                    // 512-sample buffers match the server's CHUNK_SIZE / Silero VAD window.
                    scriptProcessor = audioContext.createScriptProcessor(512, 1, 1);
                    scriptProcessor.onaudioprocess = function(event) {
                        const inputData = event.inputBuffer.getChannelData(0);
                        const pcm16 = floatTo16BitPCM(inputData);
                        if (ws.readyState === WebSocket.OPEN) {
                            ws.send(pcm16);
                        }
                        // Draw the time-domain waveform on the canvas.
                        analyser.getByteTimeDomainData(dataArray);
                        canvasContext.fillStyle = 'rgb(200, 200, 200)';
                        canvasContext.fillRect(0, 0, canvas.width, canvas.height);
                        canvasContext.lineWidth = 2;
                        canvasContext.strokeStyle = 'rgb(0, 0, 0)';
                        canvasContext.beginPath();
                        let sliceWidth = canvas.width * 1.0 / bufferLength;
                        let x = 0;
                        for (let i = 0; i < bufferLength; i++) {
                            let v = dataArray[i] / 128.0;
                            let y = v * canvas.height / 2;
                            if (i === 0) {
                                canvasContext.moveTo(x, y);
                            } else {
                                canvasContext.lineTo(x, y);
                            }
                            x += sliceWidth;
                        }
                        canvasContext.lineTo(canvas.width, canvas.height / 2);
                        canvasContext.stroke();
                    };
                    source.connect(scriptProcessor);
                    scriptProcessor.connect(audioContext.destination);
                } catch (err) {
                    document.getElementById("status").innerText = "Error: " + err;
                }
            };
            ws.onmessage = function(event) {
                const data = JSON.parse(event.data);
                if (data.type === 'partial') {
                    // In-progress transcript: shown in gray and overwritten on each refresh.
                    currentLine.style.color = 'gray';
                    currentLine.textContent = data.transcript + ' ';
                } else if (data.type === 'final') {
                    // Finalized transcript: fixed in black, then a new caption line begins.
                    currentLine.style.color = 'black';
                    currentLine.textContent = data.transcript;
                    currentLine = document.createElement('span');
                    document.getElementById('transcription').appendChild(document.createElement('br'));
                    document.getElementById('transcription').appendChild(currentLine);
                } else if (data.type === 'status') {
                    if (data.message === 'speaking_started') {
                        document.getElementById("speakingStatus").innerText = "Speaking Started";
                        document.getElementById("speakingStatus").style.color = "green";
                    } else if (data.message === 'speaking_stopped') {
                        document.getElementById("speakingStatus").innerText = "Speaking Stopped";
                        document.getElementById("speakingStatus").style.color = "red";
                    }
                }
            };
            ws.onclose = function() {
                if (audioContext && audioContext.state !== 'closed') {
                    audioContext.close();
                }
                document.getElementById("status").innerText = "Closed";
            };
        }
        // Tell the server which Whisper pipeline to use for subsequent audio.
        function switchModel() {
            const model = document.getElementById("modelSelect").value;
            if (ws && ws.readyState === WebSocket.OPEN) {
                if (model === "tiny") {
                    ws.send("switch_to_tiny");
                } else if (model === "base") {
                    ws.send("switch_to_base");
                }
            }
        }
        // Convert Float32 samples in [-1, 1] to 16-bit little-endian PCM.
        function floatTo16BitPCM(input) {
            const buffer = new ArrayBuffer(input.length * 2);
            const output = new DataView(buffer);
            for (let i = 0; i < input.length; i++) {
                let s = Math.max(-1, Math.min(1, input[i]));
                output.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
            }
            return buffer;
        }
    </script>
</body>
</html>
    """
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)