Add Twilio WebSocket media stream handler with real-time transcription
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
144
backend/main.py
144
backend/main.py
@@ -4,10 +4,14 @@ import uuid
|
|||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import FileResponse, Response
|
from fastapi.responses import FileResponse, Response
|
||||||
from twilio.twiml.voice_response import VoiceResponse
|
from twilio.twiml.voice_response import VoiceResponse
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import audioop
|
||||||
|
import time
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -849,6 +853,144 @@ async def drop_from_queue(call_sid: str):
|
|||||||
return {"status": "dropped"}
|
return {"status": "dropped"}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Twilio WebSocket Media Stream ---


@app.websocket("/api/twilio/stream")
async def twilio_media_stream(websocket: WebSocket):
    """Handle Twilio Media Streams WebSocket — bidirectional audio.

    Twilio sends JSON text frames with an "event" field:
      * "start" — carries streamSid / callSid identifiers,
      * "media" — base64-encoded mu-law audio at 8 kHz,
      * "stop"  — end of stream.

    Incoming audio is converted to 16-bit PCM, routed live to the caller's
    dedicated Loopback channel, and transcribed in ~3-second chunks by
    background tasks.
    """
    await websocket.accept()
    print("[Twilio WS] Media stream connected")

    call_sid = None
    stream_sid = None
    audio_buffer = bytearray()
    CHUNK_DURATION_S = 3  # Transcribe every 3 seconds of audio
    MULAW_SAMPLE_RATE = 8000  # Twilio media streams are always 8 kHz mu-law
    chunk_samples = CHUNK_DURATION_S * MULAW_SAMPLE_RATE
    chunk_bytes = chunk_samples * 2  # 16-bit PCM -> 2 bytes per sample

    # FIX: keep strong references to background transcription tasks.  The
    # result of a bare asyncio.create_task() may be garbage-collected before
    # the task runs to completion (documented asyncio pitfall), silently
    # dropping transcriptions.
    pending_tasks: set[asyncio.Task] = set()

    def _spawn_transcription(sid: str, chunk: bytes) -> None:
        # Fire-and-forget transcription with a retained reference.
        task = asyncio.create_task(
            _handle_real_caller_transcription(sid, chunk, MULAW_SAMPLE_RATE)
        )
        pending_tasks.add(task)
        task.add_done_callback(pending_tasks.discard)

    try:
        while True:
            data = await websocket.receive_text()
            msg = json.loads(data)
            event = msg.get("event")

            if event == "start":
                stream_sid = msg["start"]["streamSid"]
                call_sid = msg["start"]["callSid"]
                print(f"[Twilio WS] Stream started: {stream_sid} for call {call_sid}")

            elif event == "media":
                # FIX: a media frame arriving before "start" has no call_sid
                # to route or transcribe against — skip it.
                if call_sid is None:
                    continue

                # Decode mulaw audio from base64
                payload = base64.b64decode(msg["media"]["payload"])
                # Convert mulaw to 16-bit PCM.
                # NOTE(review): audioop is deprecated since 3.11 and removed
                # in Python 3.13 (PEP 594) — confirm the runtime version or
                # vendor a replacement (e.g. audioop-lts).
                pcm_data = audioop.ulaw2lin(payload, 2)
                audio_buffer.extend(pcm_data)

                # Get channel for this caller
                call_info = twilio_service.active_calls.get(call_sid)
                if call_info:
                    channel = call_info["channel"]
                    # Route PCM to the caller's dedicated Loopback channel
                    audio_service.route_real_caller_audio(pcm_data, channel, MULAW_SAMPLE_RATE)

                # When we have enough audio, transcribe in the background.
                if len(audio_buffer) >= chunk_bytes:
                    pcm_chunk = bytes(audio_buffer[:chunk_bytes])
                    audio_buffer = audio_buffer[chunk_bytes:]
                    _spawn_transcription(call_sid, pcm_chunk)

            elif event == "stop":
                print(f"[Twilio WS] Stream stopped: {stream_sid}")
                break

    except WebSocketDisconnect:
        print(f"[Twilio WS] Disconnected: {call_sid}")
    except Exception as e:
        print(f"[Twilio WS] Error: {e}")
    finally:
        # Transcribe any remaining audio
        if audio_buffer and call_sid:
            _spawn_transcription(call_sid, bytes(audio_buffer))
|
async def _handle_real_caller_transcription(call_sid: str, pcm_data: bytes, sample_rate: int):
    """Transcribe one PCM chunk from a live Twilio caller and record it.

    Looks up the active call, runs speech-to-text on the chunk, appends the
    text to the shared conversation under a "real_caller:" role, and — when
    auto-respond mode is enabled and an AI caller is on the line — schedules
    a background check for an AI interjection.
    """
    info = twilio_service.active_calls.get(call_sid)
    if not info:
        # Call already ended / unknown SID — drop the chunk.
        return

    text = await transcribe_audio(pcm_data, source_sample_rate=sample_rate)
    if not (text and text.strip()):
        # Silence or empty transcription — nothing to record.
        return

    caller_name = info["name"]
    print(f"[Real Caller] {caller_name}: {text}")

    # The "real_caller:" role prefix distinguishes humans from AI callers.
    session.add_message(f"real_caller:{caller_name}", text)

    # Only consider an AI interjection when auto mode is on AND an AI caller
    # is currently active.
    if session.ai_respond_mode == "auto" and session.current_caller_key:
        asyncio.create_task(_check_ai_auto_respond(text, caller_name))
|
async def _check_ai_auto_respond(real_caller_text: str, real_caller_name: str):
    """Check if AI caller should jump in, and generate response if so.

    Pipeline: cooldown gate -> cheap YES/NO LLM check -> full LLM response
    -> TTS -> playback on a background thread.  Returns early (doing
    nothing) at any gate.

    Args:
        real_caller_text: what the human caller just said (transcribed).
        real_caller_name: display name of the human caller (currently only
            used implicitly via the conversation history).
    """
    # No AI caller configured — nothing can respond.
    if not session.caller:
        return

    # Cooldown check — lazily initialize the timestamp attribute on first use.
    if not hasattr(session, '_last_ai_auto_respond'):
        session._last_ai_auto_respond = 0
    # At most one auto-response per 10 seconds.  NOTE(review): the timestamp
    # is only updated after a YES verdict (below), so concurrent invocations
    # inside the LLM-check window could both pass this gate — confirm that's
    # acceptable.
    if time.time() - session._last_ai_auto_respond < 10:
        return

    ai_name = session.caller["name"]

    # Quick "should I respond?" check with minimal LLM call
    should_respond = await llm_service.generate(
        messages=[{"role": "user", "content": f'Someone just said: "{real_caller_text}". Should {ai_name} jump in? Reply only YES or NO.'}],
        system_prompt=f"You're deciding if {ai_name} should respond to what was just said on a radio show. Say YES if it's interesting or relevant to them, NO if not.",
    )

    # Substring match tolerates chatty LLM replies like "Yes, definitely."
    if "YES" not in should_respond.upper():
        return

    print(f"[Auto-Respond] {ai_name} is jumping in...")
    # Stamp the cooldown only once we've committed to responding.
    session._last_ai_auto_respond = time.time()

    # Generate full response
    conversation_summary = session.get_conversation_summary()
    system_prompt = get_caller_prompt(session.caller, conversation_summary)

    # Only the last 10 conversation turns are sent to keep the prompt small.
    response = await llm_service.generate(
        messages=session.conversation[-10:],
        system_prompt=system_prompt,
    )
    response = clean_for_tts(response)
    if not response or not response.strip():
        return

    # Record with the "ai_caller:" role prefix, mirroring "real_caller:".
    session.add_message(f"ai_caller:{ai_name}", response)

    # Generate TTS and play
    audio_bytes = await generate_speech(response, session.caller["voice"], "none")

    # Playback is blocking, so run it on a daemon thread to avoid stalling
    # the event loop.  NOTE(review): 24000 is presumably the TTS output
    # sample rate — confirm against generate_speech.
    import threading
    thread = threading.Thread(
        target=audio_service.play_caller_audio,
        args=(audio_bytes, 24000),
        daemon=True,
    )
    thread.start()
# --- Server Control Endpoints ---
|
# --- Server Control Endpoints ---
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|||||||
@@ -309,6 +309,42 @@ class AudioService:
|
|||||||
"""Stop any playing caller audio"""
|
"""Stop any playing caller audio"""
|
||||||
self._caller_stop_event.set()
|
self._caller_stop_event.set()
|
||||||
|
|
||||||
|
def route_real_caller_audio(self, pcm_data: bytes, channel: int, sample_rate: int):
    """Route real caller PCM audio to a specific Loopback channel.

    Converts 16-bit signed PCM bytes to float32, resamples to the output
    device's default rate, places the audio on one channel of a zeroed
    multi-channel buffer, and writes it synchronously to the device.

    Args:
        pcm_data: mono 16-bit signed little-endian PCM bytes.
        channel: 1-based output channel index; clamped to the device's
            channel count.  NOTE(review): channel values < 1 would produce
            a negative index — assumes callers always pass >= 1.
        sample_rate: sample rate of pcm_data (Twilio streams are 8 kHz).

    Silently returns when no output device is configured; any playback
    error is caught and logged rather than raised.
    """
    # Imported lazily so the module loads even if librosa is absent.
    import librosa

    if self.output_device is None:
        return

    try:
        # Convert bytes to float32 in [-1.0, 1.0).
        audio = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0

        device_info = sd.query_devices(self.output_device)
        num_channels = device_info['max_output_channels']
        device_sr = int(device_info['default_samplerate'])
        # Clamp to the device's channel count, then convert 1-based -> 0-based.
        channel_idx = min(channel, num_channels) - 1

        # Resample from Twilio's 8kHz to device sample rate
        if sample_rate != device_sr:
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=device_sr)

        # Create multi-channel output with the audio on one channel only.
        multi_ch = np.zeros((len(audio), num_channels), dtype=np.float32)
        multi_ch[:, channel_idx] = audio

        # Write to output device.  NOTE(review): this opens and tears down a
        # fresh OutputStream per chunk, which can click/glitch between
        # chunks — a persistent per-channel stream may be smoother.
        with sd.OutputStream(
            device=self.output_device,
            samplerate=device_sr,
            channels=num_channels,
            dtype=np.float32,
        ) as stream:
            stream.write(multi_ch)

    except Exception as e:
        print(f"Real caller audio routing error: {e}")
# --- Music Playback ---
|
# --- Music Playback ---
|
||||||
|
|
||||||
def load_music(self, file_path: str) -> bool:
|
def load_music(self, file_path: str) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user