# File: health-apps-backend/services/bot.py
# (228 lines, 7.9 KiB, Python)

#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import datetime
import io
import os
import sys
import wave
import aiofiles
from dotenv import load_dotenv
from fastapi import WebSocket
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.serializers.twilio import TwilioFrameSerializer
from pipecat.services.elevenlabs import ElevenLabsTTSService
from pipecat.services.playht import PlayHTTTSService, Language
from pipecat.services.deepgram import DeepgramSTTService
from pipecat.services.fish import FishAudioTTSService
from pipecat.services.rime import RimeTTSService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai_realtime_beta import (
OpenAIRealtimeBetaLLMService,
SessionProperties,
TurnDetection,
)
from pipecat.services.anthropic import AnthropicLLMService
from pipecat.services.openai import OpenAILLMService
from pipecat.services.google import GoogleLLMService, GoogleLLMContext
from pipecat.transports.network.fastapi_websocket import (
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
# Load environment variables from a local .env file; override=True lets the
# .env values take precedence over anything already in the process environment.
load_dotenv(override=True)

# Replace loguru's default handler (id 0) with a stderr handler at DEBUG level.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def save_audio(
    server_name: str, audio: bytes, sample_rate: int, num_channels: int
) -> None:
    """Persist raw PCM audio to a timestamped WAV file on disk.

    Args:
        server_name: Prefix used to identify the server in the output filename.
        audio: Raw 16-bit PCM frames (interleaved if multi-channel).
        sample_rate: Sample rate of the audio in Hz.
        num_channels: Number of audio channels in ``audio``.
    """
    if not audio:
        # Nothing was captured; skip creating an empty WAV file.
        logger.info("No audio data to save")
        return

    filename = (
        f"{server_name}_recording_"
        f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
    )
    # Build the WAV container in memory first, then flush it to disk in one
    # async write so the event loop is not blocked by the wave module.
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as wf:
            wf.setsampwidth(2)  # 16-bit samples (2 bytes per sample)
            wf.setnchannels(num_channels)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)
        async with aiofiles.open(filename, "wb") as file:
            await file.write(buffer.getvalue())
    # Bug fix: the original log said "saved to (unknown)" and never reported
    # the actual output path.
    logger.info(f"Merged audio saved to {filename}")
async def run_bot(
    websocket_client: WebSocket, stream_sid: str, testing: bool, option: int = 1
) -> None:
    """Run a full voice-bot pipeline over a Twilio media-stream websocket.

    Wires up STT (Deepgram) -> LLM (OpenAI gpt-4o) -> TTS (selected by
    ``option``) inside a pipecat pipeline, then blocks until the pipeline
    runner finishes (client disconnect cancels the task).

    Args:
        websocket_client: Accepted FastAPI websocket carrying Twilio media.
        stream_sid: Twilio stream SID used by the frame serializer.
        testing: When True, TTS services push silence after stopping
            (used by Cartesia/ElevenLabs below); also disables the
            continuous-stream mode of the audio buffer.
        option: Selects which TTS service/voice to use (1-5; any other
            value falls through to the Rime "peak" voice).
    """
    # Transport: raw 8 kHz Twilio audio in/out over the websocket, with
    # Silero VAD driving turn detection. add_wav_header=False because Twilio
    # expects raw payloads, not WAV-framed audio.
    transport = FastAPIWebsocketTransport(
        websocket=websocket_client,
        params=FastAPIWebsocketParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            add_wav_header=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,  # pass audio through so it can be recorded
            serializer=TwilioFrameSerializer(stream_sid),
        ),
    )

    # Alternative LLM backends kept for experimentation (disabled):
    # llm = OpenAIRealtimeBetaLLMService(
    #     api_key=os.getenv("OPENAI_API_KEY"),
    #     session_properties=SessionProperties(
    #         modalities=["text"],
    #         turn_detection=TurnDetection(threshold=0.5, silence_duration_ms=800),
    #         voice=None,
    #     ),
    # )
    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
    # NOTE(review): env var name below is misspelled ("ANTRHOPIC") — fix if re-enabled.
    # llm = AnthropicLLMService(api_key=os.getenv("ANTRHOPIC_API_KEY"))
    # llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"), model="phone_call")

    # audio_passthrough=True so downstream processors (audio buffer) still
    # receive the raw audio after transcription.
    stt = DeepgramSTTService(
        api_key=os.getenv("DEEPGRAM_API_KEY"), audio_passthrough=True
    )

    # Disabled TTS experiments (both marked "not working"):
    # tts = PlayHTTTSService(
    #     api_key=os.getenv("PLAYHT_SECRE_KEY"),
    #     user_id=os.getenv("PLAYHT_USERID"),
    #     voice_url="s3://voice-cloning-zero-shot/80ba8839-a6e6-470c-8f68-7c1e5d3ee2ff/abigailsaad/manifest.json",
    #     params=PlayHTTTSService.InputParams(
    #         language=Language.EN,
    #         speed=1.0,
    #     ),
    # ) # not working
    # tts = FishAudioTTSService(
    #     api_key=os.getenv("FISH_AUDIO_API_KEY"),
    #     model="b545c585f631496c914815291da4e893", # Get this from Fish Audio playground
    #     output_format="pcm", # Choose output format
    #     sample_rate=24000, # Set sample rate
    #     params=FishAudioTTSService.InputParams(latency="normal", prosody_speed=1.0),
    # ) # not working

    # TTS selection by ``option``; unknown values fall through to Rime "peak".
    if option == 1:
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="156fb8d2-335b-4950-9cb3-a2d33befec77",  # British Lady
            push_silence_after_stop=testing,
        )
    elif option == 2:
        tts = RimeTTSService(
            api_key=os.getenv("RIME_API_KEY"),
            voice_id="stream",
            model="mistv2",
        )
    elif option == 3:
        tts = ElevenLabsTTSService(
            api_key=os.getenv("ELEVEN_LABS_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",
            push_silence_after_stop=testing,
        )
    elif option == 4:
        tts = RimeTTSService(
            api_key=os.getenv("RIME_API_KEY"),
            voice_id="breeze",
            model="mistv2",
        )
    elif option == 5:
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            # NOTE(review): different voice id than option 1 but was labeled
            # "British Lady" too — confirm which voice this actually is.
            voice_id="1d3ba41a-96e6-44ad-aabb-9817c56caa68",
            push_silence_after_stop=testing,
        )
    else:
        tts = RimeTTSService(
            api_key=os.getenv("RIME_API_KEY"),
            voice_id="peak",
            model="mistv2",
        )

    # System prompt establishing the receptionist persona and legal notices.
    messages = [
        {
            "role": "system",
            "content": f"""
Welcome to 365 Days Medical Centre Para Hills - we care about you.
If this is an emergency, please call triple zero.
We are open from 8 AM to 8 PM every day of the year.
All calls are recorded for training and quality purposes - please let us know if you do not wish to be recorded.
I am Nishka, your 24/7 healthcare receptionist. Which language would you like to speak?
""",
        }
    ]

    context = OpenAILLMContext(messages)
    context_aggregator = llm.create_context_aggregator(context)

    # NOTE: Watch out! This will save all the conversation in memory. You can
    # pass `buffer_size` to get periodic callbacks.
    audiobuffer = AudioBufferProcessor(user_continuous_stream=not testing)

    # Processor order matters: STT must precede the user context aggregator,
    # and the assistant aggregator sits after transport output so it captures
    # what was actually spoken.
    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-To-Text
            context_aggregator.user(),  # User context
            llm,  # LLM
            tts,  # Text-To-Speech
            transport.output(),  # Websocket output to client
            audiobuffer,  # Used to buffer the audio in the pipeline
            context_aggregator.assistant(),  # Assistant context
        ]
    )

    # 8 kHz in/out matches Twilio's telephony audio format.
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            audio_in_sample_rate=8000,
            audio_out_sample_rate=8000,
            allow_interruptions=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Start recording.
        await audiobuffer.start_recording()
        # Kick off the conversation.
        messages.append(
            {"role": "system", "content": "Please introduce yourself to the user."}
        )
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        # Tear down the pipeline when the caller hangs up.
        await task.cancel()

    # Disabled: persist the buffered call audio to disk via save_audio().
    # @audiobuffer.event_handler("on_audio_data")
    # async def on_audio_data(buffer, audio, sample_rate, num_channels):
    #     server_name = f"server_{websocket_client.client.port}"
    #     await save_audio(server_name, audio, sample_rate, num_channels)

    # We use `handle_sigint=False` because `uvicorn` is controlling keyboard
    # interruptions. We use `force_gc=True` to force garbage collection after
    # the runner finishes running a task which could be useful for long running
    # applications with multiple clients connecting.
    runner = PipelineRunner(handle_sigint=False, force_gc=True)

    await runner.run(task)