228 lines
7.9 KiB
Python
228 lines
7.9 KiB
Python
#
|
|
# Copyright (c) 2025, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
import datetime
|
|
import io
|
|
import os
|
|
import sys
|
|
import wave
|
|
|
|
import aiofiles
|
|
from dotenv import load_dotenv
|
|
from fastapi import WebSocket
|
|
from loguru import logger
|
|
|
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.runner import PipelineRunner
|
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
|
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
|
from pipecat.serializers.twilio import TwilioFrameSerializer
|
|
|
|
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
|
from pipecat.services.playht import PlayHTTTSService, Language
|
|
from pipecat.services.deepgram import DeepgramSTTService
|
|
from pipecat.services.fish import FishAudioTTSService
|
|
from pipecat.services.rime import RimeTTSService
|
|
from pipecat.services.cartesia import CartesiaTTSService
|
|
|
|
from pipecat.services.openai_realtime_beta import (
|
|
OpenAIRealtimeBetaLLMService,
|
|
SessionProperties,
|
|
TurnDetection,
|
|
)
|
|
from pipecat.services.anthropic import AnthropicLLMService
|
|
from pipecat.services.openai import OpenAILLMService
|
|
from pipecat.services.google import GoogleLLMService, GoogleLLMContext
|
|
from pipecat.transports.network.fastapi_websocket import (
|
|
FastAPIWebsocketParams,
|
|
FastAPIWebsocketTransport,
|
|
)
|
|
|
|
# Load environment variables from a local .env file; `override=True` lets the
# .env values win over variables already present in the process environment.
load_dotenv(override=True)

# Replace loguru's default handler (id 0) with a DEBUG-level sink on stderr.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
|
|
|
|
|
|
async def save_audio(
    server_name: str, audio: bytes, sample_rate: int, num_channels: int
):
    """Persist raw PCM audio to a timestamped WAV file on disk.

    The audio is first wrapped with a WAV header in an in-memory buffer
    (16-bit sample width), then written asynchronously via ``aiofiles`` so
    the event loop is not blocked by disk I/O.

    Args:
        server_name: Prefix used to build the output filename.
        audio: Raw PCM audio bytes; assumed 16-bit samples — TODO confirm
            against the AudioBufferProcessor output format.
        sample_rate: Sample rate in Hz recorded in the WAV header.
        num_channels: Number of interleaved channels in ``audio``.
    """
    # Guard clause: nothing to write, just log and bail out early.
    if not audio:
        logger.info("No audio data to save")
        return

    filename = f"{server_name}_recording_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as wf:
            wf.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
            wf.setnchannels(num_channels)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)
        async with aiofiles.open(filename, "wb") as file:
            await file.write(buffer.getvalue())
    # Fix: the log line previously read "Merged audio saved to (unknown)";
    # it now reports the actual file that was written.
    logger.info(f"Merged audio saved to {filename}")
|
|
|
|
|
|
async def run_bot(
    websocket_client: WebSocket, stream_sid: str, testing: bool, option: int = 1
):
    """Run a voice-bot pipeline over a Twilio media-stream websocket.

    Builds a Pipecat pipeline of the form
    transport.input -> STT -> user context -> LLM -> TTS -> transport.output
    -> audio buffer -> assistant context, then runs it until the client
    disconnects.

    Args:
        websocket_client: Accepted FastAPI websocket carrying the call audio.
        stream_sid: Twilio stream SID, passed to the Twilio frame serializer.
        testing: When True, the Cartesia/ElevenLabs TTS services push silence
            after stopping, and the audio buffer's continuous user stream is
            disabled (``user_continuous_stream=not testing``).
        option: Selects the TTS service/voice (1-5); any other value falls
            back to the Rime "peak" voice.
    """
    # Websocket transport with Silero VAD; audio passes through the VAD so
    # downstream processors (STT, recorder) still see raw frames. Frames are
    # (de)serialized in Twilio's media-stream format.
    transport = FastAPIWebsocketTransport(
        websocket=websocket_client,
        params=FastAPIWebsocketParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            add_wav_header=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,
            serializer=TwilioFrameSerializer(stream_sid),
        ),
    )

    # Alternative LLM: OpenAI Realtime (speech-to-speech) — currently disabled.
    # llm = OpenAIRealtimeBetaLLMService(
    #     api_key=os.getenv("OPENAI_API_KEY"),
    #     session_properties=SessionProperties(
    #         modalities=["text"],
    #         turn_detection=TurnDetection(threshold=0.5, silence_duration_ms=800),
    #         voice=None,
    #     ),
    # )

    # Active LLM: OpenAI gpt-4o over the standard chat-completions API.
    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

    # Alternative LLMs — currently disabled.
    # NOTE(review): "ANTRHOPIC_API_KEY" looks misspelled — verify the env var
    # name before re-enabling.
    # llm = AnthropicLLMService(api_key=os.getenv("ANTRHOPIC_API_KEY"))

    # llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"), model="phone_call")

    # Speech-to-text; audio_passthrough keeps raw audio flowing to the
    # recorder further down the pipeline.
    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY"), audio_passthrough=True
    )

    # NOTE(review): "PLAYHT_SECRE_KEY" looks misspelled — verify before
    # re-enabling this block.
    # tts = PlayHTTTSService(
    #     api_key=os.getenv("PLAYHT_SECRE_KEY"),
    #     user_id=os.getenv("PLAYHT_USERID"),
    #     voice_url="s3://voice-cloning-zero-shot/80ba8839-a6e6-470c-8f68-7c1e5d3ee2ff/abigailsaad/manifest.json",
    #     params=PlayHTTTSService.InputParams(
    #         language=Language.EN,
    #         speed=1.0,
    #     ),
    # ) # not working

    # tts = FishAudioTTSService(
    #     api_key=os.getenv("FISH_AUDIO_API_KEY"),
    #     model="b545c585f631496c914815291da4e893", # Get this from Fish Audio playground
    #     output_format="pcm", # Choose output format
    #     sample_rate=24000, # Set sample rate
    #     params=FishAudioTTSService.InputParams(latency="normal", prosody_speed=1.0),
    # ) # not working

    # Select the text-to-speech backend by `option`.
    if option == 1:
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="156fb8d2-335b-4950-9cb3-a2d33befec77",  # British Lady
            push_silence_after_stop=testing,
        )
    elif option == 2:
        tts = RimeTTSService(
            api_key=os.getenv("RIME_API_KEY"),
            voice_id="stream",
            model="mistv2",
        )
    elif option == 3:
        # NOTE(review): this voice_id looks like a Cartesia-style UUID rather
        # than an ElevenLabs voice id — confirm it is valid for ElevenLabs.
        tts = ElevenLabsTTSService(
            api_key=os.getenv("ELEVEN_LABS_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",
            push_silence_after_stop=testing,
        )
    elif option == 4:
        tts = RimeTTSService(
            api_key=os.getenv("RIME_API_KEY"),
            voice_id="breeze",
            model="mistv2",
        )
    elif option == 5:
        # NOTE(review): labeled "British Lady" like option 1 but uses a
        # different voice_id — confirm which label is correct.
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="1d3ba41a-96e6-44ad-aabb-9817c56caa68",  # British Lady
            push_silence_after_stop=testing,
        )
    else:
        # Fallback for any unrecognized option value.
        tts = RimeTTSService(
            api_key=os.getenv("RIME_API_KEY"),
            voice_id="peak",
            model="mistv2",
        )

    # System prompt that seeds the conversation context.
    messages = [
        {
            "role": "system",
            "content": f"""
            Welcome to 365 Days Medical Centre Para Hills - we care about you.
            If this is an emergency, please call triple zero.
            We are open from 8 AM to 8 PM every day of the year.
            All calls are recorded for training and quality purposes - please let us know if you do not wish to be recorded.
            I am Nishka, your 24/7 healthcare receptionist. Which language would you like to speak?
            """,
        }
    ]

    context = OpenAILLMContext(messages)
    context_aggregator = llm.create_context_aggregator(context)

    # NOTE: Watch out! This will save all the conversation in memory. You can
    # pass `buffer_size` to get periodic callbacks.
    audiobuffer = AudioBufferProcessor(user_continuous_stream=not testing)

    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-To-Text
            context_aggregator.user(),  # User context
            llm,  # LLM
            tts,  # Text-To-Speech
            transport.output(),  # Websocket output to client
            audiobuffer,  # Used to buffer the audio in the pipeline
            context_aggregator.assistant(),  # Assistant context
        ]
    )

    # 8000 Hz in/out — presumably to match Twilio's telephony audio rate;
    # verify against the Twilio media-stream configuration.
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            audio_in_sample_rate=8000,
            audio_out_sample_rate=8000,
            allow_interruptions=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Start recording.
        await audiobuffer.start_recording()
        # Kick off the conversation.
        messages.append(
            {"role": "system", "content": "Please introduce yourself to the user."}
        )
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        # Tear down the pipeline when the caller hangs up.
        await task.cancel()

    # Disabled: per-callback audio persistence via save_audio().
    # @audiobuffer.event_handler("on_audio_data")
    # async def on_audio_data(buffer, audio, sample_rate, num_channels):
    #     server_name = f"server_{websocket_client.client.port}"
    #     await save_audio(server_name, audio, sample_rate, num_channels)

    # We use `handle_sigint=False` because `uvicorn` is controlling keyboard
    # interruptions. We use `force_gc=True` to force garbage collection after
    # the runner finishes running a task which could be useful for long running
    # applications with multiple clients connecting.
    runner = PipelineRunner(handle_sigint=False, force_gc=True)

    await runner.run(task)
|