"""Service layer for call transcripts: create, paginated listing, and
single/bulk download of transcript files stored in S3 (via signed URLs)."""

import datetime
import os
import shutil
import tempfile
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from fastapi import BackgroundTasks
from fastapi.responses import FileResponse
from loguru import logger
from sqlalchemy import desc
from sqlalchemy.orm import Session

from database import get_db
from exceptions.business_exception import BusinessValidationException
from exceptions.db_exceptions import DBExceptionHandler
from interface.common_response import CommonResponse
from models.CallTranscripts import CallTranscripts
from schemas.CreateSchemas import CallTranscriptsCreate
from schemas.ResponseSchemas import CallTranscriptsResponse
from services.s3Service import get_signed_url


class CallTranscriptServices:
    """CRUD and download operations for call transcripts.

    NOTE(review): each method closes ``self.db`` in its ``finally`` block,
    so a single service instance cannot safely serve two calls. Callers
    appear to construct a fresh instance per request — confirm before
    reusing instances.
    """

    def __init__(self):
        # One session per service instance, pulled from the app's generator.
        self.db: Session = next(get_db())
        self.logger = logger

    async def create_call_transcript(self, data: CallTranscriptsCreate):
        """Persist a new call-transcript row built from *data*.

        Args:
            data: validated creation payload (pydantic model).

        Raises:
            Whatever ``DBExceptionHandler.handle_exception`` raises on a
            database failure.
        """
        try:
            call_transcript = CallTranscripts(**data.model_dump())
            self.db.add(call_transcript)
            self.db.commit()
            return
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="creating call transcript")
        finally:
            self.db.close()

    async def get_call_transcripts(
        self,
        limit: int,
        offset: int,
        search: str = "",
        orderBy: str = "call_received_time",
        order: str = "ASC",
        startDate: Optional[datetime.datetime] = None,
        endDate: Optional[datetime.datetime] = None,
    ):
        """Return one page of call transcripts plus the matching total.

        Args:
            limit/offset: pagination window.
            search: substring match against ``patient_number``.
            orderBy: column name to sort on; ``order`` is "ASC" or "DESC".
            startDate/endDate: inclusive ``call_received_time`` window
                (both must be given for the filter to apply).

        Returns:
            ``CommonResponse`` with the page data and a ``total`` that
            reflects the SAME filters as the page.
        """
        try:
            query = self.db.query(CallTranscripts)
            if search:
                query = query.filter(CallTranscripts.patient_number.contains(search))
            if startDate and endDate:
                query = query.filter(
                    CallTranscripts.call_received_time.between(startDate, endDate)
                )
            # Fix: count the FILTERED set before pagination. The previous
            # code counted the whole table, so `total` disagreed with the
            # page whenever search/date filters were active.
            total = query.count()
            sort_column = getattr(CallTranscripts, orderBy)
            query = query.order_by(
                desc(sort_column) if order == "DESC" else sort_column
            )
            call_transcripts = query.limit(limit).offset(offset).all()
            response = [
                CallTranscriptsResponse(**row.__dict__.copy())
                for row in call_transcripts
            ]
            # Swap the raw S3 key for a time-limited signed URL in each item.
            for item in response:
                item.transcript_key_id = await get_signed_url(item.transcript_key_id)
            return CommonResponse(data=response, total=total)
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="getting call transcripts")
        finally:
            self.db.close()

    async def download_call_transcript(self, key_id: str):
        """Return a signed download URL for the transcript with *key_id*.

        Raises:
            BusinessValidationException: no transcript matches *key_id*.
        """
        try:
            call_transcript = (
                self.db.query(CallTranscripts)
                .filter(CallTranscripts.transcript_key_id == key_id)
                .first()
            )
            if not call_transcript:
                raise BusinessValidationException("Call transcript not found!")
            # Fix: get_signed_url is a coroutine (awaited everywhere else in
            # this file) — the old code returned the un-awaited coroutine
            # object instead of the URL.
            return await get_signed_url(call_transcript.transcript_key_id)
        except BusinessValidationException:
            # Fix: don't let the broad handler below mask a validation
            # error as a database error.
            raise
        except Exception as e:
            DBExceptionHandler.handle_exception(e, context="downloading call transcript")
        finally:
            self.db.close()

    def download_file(self, url: str, file_path: str) -> None:
        """Download a file from a signed URL to a local path (best effort).

        Failures are logged, not raised — callers treat missing files as
        skippable when assembling the zip.

        Args:
            url: the pre-signed URL to download from.
            file_path: the local path to save the file to.
        """
        try:
            import requests

            # Fix: a timeout so a stalled S3 connection cannot hang a
            # ThreadPoolExecutor worker forever.
            response = requests.get(url, timeout=60)
            if response.status_code == 200:
                with open(file_path, "wb") as f:
                    f.write(response.content)
            else:
                self.logger.error(f"Failed to download file: {response.status_code}")
        except Exception as e:
            self.logger.error(f"Error downloading file: {e}")

    def cleanup_temp_files(self, temp_dir: str, zip_path: str) -> None:
        """Clean up temporary files after the zip response has been sent.

        Args:
            temp_dir: directory containing temporary files.
            zip_path: path to the zip file.
        """
        try:
            # Give the response a moment to finish streaming before
            # deleting the file it is served from.
            time.sleep(5)
            if os.path.exists(zip_path):
                os.remove(zip_path)
            # Fix: shutil.rmtree handles nested content; the old manual
            # os.remove loop failed if the directory ever held a subdir.
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception as e:
            self.logger.error(f"Error during cleanup: {e}")

    async def bulk_download_call_transcripts(
        self, key_ids: list[int], background_tasks: BackgroundTasks
    ):
        """Zip the transcripts for *key_ids* and return the archive.

        Transcript files are fetched concurrently via signed URLs into a
        temp directory, zipped, and returned as a FileResponse; the temp
        files are removed by a background task after the response is sent.

        Raises:
            BusinessValidationException: none of *key_ids* exist.
        """
        try:
            rows = (
                self.db.query(CallTranscripts)
                .filter(CallTranscripts.id.in_(key_ids))
                .all()
            )
            keys = [row.transcript_key_id for row in rows]
            if not keys:
                raise BusinessValidationException("No call transcripts found!")

            temp_dir = tempfile.mkdtemp(prefix="call_transcripts_")
            zip_path = os.path.join(temp_dir, "call_transcripts.zip")

            # (signed url, local path, archive name) for each transcript.
            download_info = []
            for key in keys:
                url = await get_signed_url(key)
                filename = os.path.basename(key)
                file_path = os.path.join(temp_dir, filename)
                download_info.append((url, file_path, filename))

            # Cap at 32 threads (or the number of files, whichever is
            # smaller) to stay within S3 rate limits.
            max_workers = min(32, len(download_info))
            file_paths = []
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_file = {
                    executor.submit(self.download_file, url, file_path): (
                        file_path,
                        filename,
                    )
                    for url, file_path, filename in download_info
                }
                for future in as_completed(future_to_file):
                    file_path, filename = future_to_file[future]
                    try:
                        future.result()  # surface any worker exception
                        file_paths.append((file_path, filename))
                    except Exception as e:
                        # Fix: include the filename (the original message
                        # lost it); a failed file is skipped, not fatal.
                        self.logger.error(f"Error downloading {filename}: {e}")

            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
                for file_path, arcname in file_paths:
                    if os.path.exists(file_path):
                        zip_file.write(file_path, arcname=arcname)

            # Fix: cleanup was commented out, leaking one temp directory
            # (and zip) per request. FastAPI runs this after the response
            # has been sent; cleanup_temp_files also sleeps briefly.
            background_tasks.add_task(self.cleanup_temp_files, temp_dir, zip_path)

            return FileResponse(
                path=zip_path,
                media_type="application/zip",
                filename="call_transcripts.zip",
            )
        except BusinessValidationException:
            # Don't mask a validation error as a database error.
            raise
        except Exception as e:
            DBExceptionHandler.handle_exception(
                e, context="bulk downloading call transcripts"
            )
        finally:
            self.db.close()