# health-apps-cms/services/callTranscripts.py

import datetime
import os
import shutil
import tempfile
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from fastapi import BackgroundTasks
from fastapi.responses import FileResponse
from loguru import logger
from sqlalchemy import desc
from sqlalchemy.orm import Session

from database import get_db
from exceptions.business_exception import BusinessValidationException
from exceptions.db_exceptions import DBExceptionHandler
from interface.common_response import CommonResponse
from models.CallTranscripts import CallTranscripts
from schemas.CreateSchemas import CallTranscriptsCreate
from schemas.ResponseSchemas import CallTranscriptsResponse
from services.s3Service import get_signed_url
class CallTranscriptServices:
    """Service layer for creating, listing, and downloading call transcripts.

    Transcript files live in S3 (addressed by ``transcript_key_id``); metadata
    rows live in the ``CallTranscripts`` table. Every public method closes the
    session when it finishes because each instance owns exactly one session.
    """

    def __init__(self):
        # NOTE(review): pulling one session out of the FastAPI dependency
        # generator with next() skips the generator's own teardown; the
        # explicit self.db.close() in each method compensates, but per-request
        # session injection would be cleaner — confirm with the router setup.
        self.db: Session = next(get_db())
        self.logger = logger

    async def create_call_transcript(self, data: CallTranscriptsCreate):
        """Persist a single call-transcript row and commit.

        Args:
            data: Validated creation payload; its fields map 1:1 onto the
                ``CallTranscripts`` model columns.
        """
        try:
            call_transcript = CallTranscripts(**data.model_dump())
            self.db.add(call_transcript)
            self.db.commit()
            return
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="creating call transcript")
        finally:
            self.db.close()

    async def get_call_transcripts(self, limit: int, offset: int, search: str = "",
                                   orderBy: str = "call_received_time", order: str = "ASC",
                                   startDate: Optional[datetime.datetime] = None,
                                   endDate: Optional[datetime.datetime] = None):
        """Return one page of transcripts with signed download URLs.

        Args:
            limit/offset: Pagination window.
            search: Substring match against ``patient_number``.
            orderBy: Column name on ``CallTranscripts`` to sort by.
            order: "DESC" for descending; anything else sorts ascending.
            startDate/endDate: Inclusive window on ``call_received_time``;
                both must be supplied for the filter to apply.

        Returns:
            CommonResponse with the page data and the *filtered* total.
        """
        try:
            sort_column = getattr(CallTranscripts, orderBy)
            query = self.db.query(CallTranscripts).order_by(
                desc(sort_column) if order == "DESC" else sort_column
            )
            if search:
                query = query.filter(CallTranscripts.patient_number.contains(search))
            if startDate and endDate:
                query = query.filter(
                    CallTranscripts.call_received_time.between(startDate, endDate)
                )
            # Fix: count the filtered set, not the whole table, so pagination
            # totals are correct when search/date filters are active.
            total = query.count()
            call_transcripts = query.limit(limit).offset(offset).all()
            response = [
                CallTranscriptsResponse(**row.__dict__.copy())
                for row in call_transcripts
            ]
            # Replace the raw S3 key with a browser-usable signed URL.
            for call_transcript in response:
                call_transcript.transcript_key_id = await get_signed_url(
                    call_transcript.transcript_key_id
                )
            return CommonResponse(data=response, total=total)
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="getting call transcripts")
        finally:
            self.db.close()

    async def download_call_transcript(self, key_id: str):
        """Return a signed URL for one transcript identified by its S3 key.

        Raises:
            BusinessValidationException: If no row matches ``key_id``.
        """
        try:
            call_transcript = (
                self.db.query(CallTranscripts)
                .filter(CallTranscripts.transcript_key_id == key_id)
                .first()
            )
            if not call_transcript:
                raise BusinessValidationException("Call transcript not found!")
            # Fix: get_signed_url is a coroutine (awaited everywhere else in
            # this class); without await the caller received a coroutine
            # object instead of the URL string.
            return await get_signed_url(call_transcript.transcript_key_id)
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="downloading call transcript")
        finally:
            self.db.close()

    def download_file(self, url: str, file_path: str) -> None:
        """Download a file from a signed URL to a local path.

        Best-effort: failures are logged, never raised, so one bad file does
        not abort a bulk download.

        Args:
            url: The pre-signed URL to download from.
            file_path: The local path to save the file to.
        """
        try:
            # Local import keeps the module importable where requests is absent.
            import requests
            # Fix: a timeout prevents a stalled S3 connection from pinning a
            # ThreadPoolExecutor worker forever.
            response = requests.get(url, timeout=60)
            if response.status_code == 200:
                with open(file_path, 'wb') as f:
                    f.write(response.content)
            else:
                self.logger.error(f"Failed to download file: {response.status_code}")
        except Exception as e:
            self.logger.error(f"Error downloading file: {e}")

    def cleanup_temp_files(self, temp_dir: str, zip_path: str) -> None:
        """Clean up temporary files after the zip response has been sent.

        Args:
            temp_dir: Directory containing the downloaded transcript files.
            zip_path: Path to the zip file (inside ``temp_dir``).
        """
        try:
            # Margin so the FileResponse has finished streaming the zip.
            time.sleep(5)
            if os.path.exists(zip_path):
                os.remove(zip_path)
            # Fix: shutil.rmtree replaces the manual listdir/remove/rmdir loop,
            # which would have failed on any nested directory.
            shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception as e:
            self.logger.error(f"Error during cleanup: {e}")

    async def bulk_download_call_transcripts(self, key_ids: list[int], background_tasks: BackgroundTasks):
        """Download several transcripts concurrently and return them zipped.

        Args:
            key_ids: Primary keys of the ``CallTranscripts`` rows to fetch.
            background_tasks: FastAPI background-task collector; used to
                delete the temp files after the response is sent.

        Raises:
            BusinessValidationException: If none of the ids match a row.
        """
        try:
            rows = self.db.query(CallTranscripts).filter(
                CallTranscripts.id.in_(key_ids)
            ).all()
            keys = [row.transcript_key_id for row in rows]
            if len(keys) < 1:
                raise BusinessValidationException("No call transcripts found!")
            temp_dir = tempfile.mkdtemp(prefix="call_transcripts_")
            zip_path = os.path.join(temp_dir, "call_transcripts.zip")
            # Pre-compute (signed URL, local path, archive name) per key.
            download_info = []
            for key in keys:
                url = await get_signed_url(key)
                filename = os.path.basename(key)
                file_path = os.path.join(temp_dir, filename)
                download_info.append((url, file_path, filename))
            # Concurrent downloads; cap threads at 32 or the file count,
            # whichever is smaller, to respect S3 rate limits.
            max_workers = min(32, len(download_info))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_file = {
                    executor.submit(self.download_file, url, file_path): (file_path, filename)
                    for url, file_path, filename in download_info
                }
                file_paths = []
                for future in as_completed(future_to_file):
                    file_path, filename = future_to_file[future]
                    try:
                        future.result()  # surface any exception from the worker
                        file_paths.append((file_path, filename))
                    except Exception as e:
                        # Fix: log the actual filename instead of "(unknown)".
                        self.logger.error(f"Error downloading {filename}: {e}")
            # Zip whatever was downloaded successfully.
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                for file_path, arcname in file_paths:
                    if os.path.exists(file_path):
                        zip_file.write(file_path, arcname=arcname)
            # Fix: schedule cleanup after the response is sent; previously the
            # temp dir and zip leaked on every bulk download because this
            # call was commented out.
            background_tasks.add_task(self.cleanup_temp_files, temp_dir, zip_path)
            return FileResponse(
                path=zip_path,
                media_type="application/zip",
                filename="call_transcripts.zip",
            )
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="bulk downloading call transcripts")
        finally:
            self.db.close()