139 lines
3.8 KiB
Python
139 lines
3.8 KiB
Python
import logging
import os
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, Optional
from urllib.parse import unquote, urlparse

import boto3
from botocore.config import Config
from botocore.exceptions import BotoCoreError, ClientError
from fastapi import HTTPException
from pydantic_settings import BaseSettings, SettingsConfigDict

from enums.enums import S3FolderNameEnum
from exceptions.business_exception import BusinessValidationException
|
|
|
|
|
|
class Settings(BaseSettings):
    """AWS/S3 configuration loaded from the environment and a ``.env`` file.

    Keys present in the environment / ``.env`` that are not declared here are
    ignored, so this model can share an env file with other settings.
    """

    AWS_REGION: str
    AWS_ACCESS_KEY: str
    AWS_SECRET_KEY: str
    AWS_BUCKET_NAME: str
    # Lifetime, in seconds, passed as ``ExpiresIn`` when pre-signing URLs.
    AWS_S3_EXPIRES: int = 60 * 60  # Default 1 hour

    # pydantic v2 / pydantic-settings configure models via ``model_config``;
    # the inner ``class Config`` form is deprecated.
    model_config = SettingsConfigDict(
        env_file=".env",
        extra="ignore",  # Allow extra fields from environment
    )
|
|
|
|
class S3Service:
    """Bundle of loaded :class:`Settings`, the target bucket and an S3 client.

    Attributes:
        settings: The ``Settings`` instance read at construction time.
        bucket_name: Bucket used by the helpers in this module.
        s3: boto3 S3 client built from the explicit credentials and region in
            ``settings``, signing with Signature Version 4 (``s3v4``).
    """

    def __init__(self):
        cfg = Settings()
        self.settings = cfg
        self.bucket_name = cfg.AWS_BUCKET_NAME

        client_kwargs = {
            "region_name": cfg.AWS_REGION,
            "aws_access_key_id": cfg.AWS_ACCESS_KEY,
            "aws_secret_access_key": cfg.AWS_SECRET_KEY,
            "config": Config(signature_version="s3v4"),
        }
        self.s3 = boto3.client("s3", **client_kwargs)
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_s3_service():
    """Return the shared :class:`S3Service` instance.

    The service (its ``Settings`` and boto3 client) is constructed on the
    first call and reused afterwards. Previously every call re-read the
    configuration and built a brand-new boto3 client.

    If construction raises, nothing is cached and the next call retries.
    Call ``get_s3_service.cache_clear()`` to force a rebuild, e.g. after
    changing environment variables in tests.
    """
    return S3Service()
|
|
|
|
async def upload_file(
    folder: S3FolderNameEnum,
    file_name: str,
) -> Dict[str, str]:
    """
    Generate pre-signed URLs for uploading a file to S3 and reading it back.

    The object key is ``<prefix>/<epoch-milliseconds>_<file_name>``. The
    prefix is ``common`` for ``S3FolderNameEnum.PROFILE`` and ``assets`` for
    every other folder. ``file_name`` is embedded in the key verbatim.

    Args:
        folder: The folder enum that selects the key prefix.
        file_name: The name of the file.

    Returns:
        Dict containing:
          - ``api_url``: pre-signed ``put_object`` URL to upload to.
          - ``key``: the generated S3 object key.
          - ``location``: ``<scheme>://<host of api_url>/<key>``.
          - ``get_url``: pre-signed ``get_object`` URL for the same key.
        Both URLs expire after ``Settings.AWS_S3_EXPIRES`` seconds.

    Raises:
        BusinessValidationException: if botocore fails to generate either URL.
    """
    s3_service = get_s3_service()
    expires_in = s3_service.settings.AWS_S3_EXPIRES

    # Millisecond timestamp prefix keeps repeated uploads of the same file
    # name from overwriting each other. NOTE(review): two uploads of the same
    # name within the same millisecond would still collide.
    timestamp = int(datetime.now().timestamp() * 1000)
    prefix = "common" if folder == S3FolderNameEnum.PROFILE else "assets"
    key = f"{prefix}/{timestamp}_{file_name}"

    def _presign(client_method: str) -> str:
        # Build a fresh Params dict per call rather than sharing one, in case
        # botocore's parameter handlers mutate it.
        return s3_service.s3.generate_presigned_url(
            ClientMethod=client_method,
            Params={
                'Bucket': s3_service.bucket_name,
                'Key': key,
            },
            ExpiresIn=expires_in,
        )

    try:
        put_url = _presign('put_object')
        get_url = _presign('get_object')
    except (ClientError, BotoCoreError) as e:
        # BotoCoreError covers client-side failures (e.g. parameter
        # validation) that are not ClientError subclasses.
        logging.getLogger(__name__).exception(
            "Error generating pre-signed URL: %s", e
        )
        raise BusinessValidationException(str(e)) from e

    url = urlparse(put_url)
    # NOTE(review): assumes virtual-hosted-style URLs (bucket in the host).
    # With path-style addressing the bucket lives in the path and would be
    # missing from ``location``. The key is also inserted without
    # percent-encoding.
    return {
        "api_url": put_url,
        "key": key,
        "location": f"{url.scheme}://{url.netloc}/{key}",
        "get_url": get_url,
    }
|
|
|
|
async def get_signed_url(key: str, expires_in: Optional[int] = None) -> str:
    """
    Generate a pre-signed URL for retrieving a file from S3.

    Args:
        key: The key of the file in S3.
        expires_in: Lifetime of the URL in seconds. When omitted, uses
            ``Settings.AWS_S3_EXPIRES`` (1 hour by default). This is the same
            lifetime ``upload_file`` uses, instead of a separate hard-coded
            value.

    Returns:
        The pre-signed URL for getting the object.

    Raises:
        BusinessValidationException: if botocore cannot generate the URL,
            e.g. a ``key`` rejected by botocore's parameter validation.
    """
    s3_service = get_s3_service()
    if expires_in is None:
        expires_in = s3_service.settings.AWS_S3_EXPIRES

    try:
        return s3_service.s3.generate_presigned_url(
            ClientMethod='get_object',
            Params={
                'Bucket': s3_service.bucket_name,
                'Key': key,
            },
            ExpiresIn=expires_in,
        )
    except (ClientError, BotoCoreError) as e:
        # BotoCoreError covers client-side failures (e.g. parameter
        # validation) that are not ClientError subclasses.
        logging.getLogger(__name__).exception("Error in get_signed_url: %s", e)
        raise BusinessValidationException(str(e)) from e
|
|
|
|
def get_file_key(url: str) -> str:
    """
    Extract the S3 object key from a URL, or return the input if it is
    already a key.

    Input that does not start with ``http://`` or ``https://`` (compared
    case-insensitively) is treated as a key and returned unchanged.

    For a URL, the path with its leading slashes removed is returned,
    percent-decoded. Pre-signed URLs carry a percent-encoded path (e.g.
    ``%20`` for a space), so decoding maps them back to the real key. The
    query string and fragment are discarded.

    Caveat: a key that literally contains ``%XX`` sequences inside an
    unencoded URL will be decoded too.

    Args:
        url: The URL or key.

    Returns:
        The file key.

    Raises:
        BusinessValidationException: if ``url`` is not a string or cannot be
            parsed as a URL.
    """
    if not isinstance(url, str):
        raise BusinessValidationException(
            f"Expected a URL or key string, got {type(url).__name__}"
        )

    if not url.lower().startswith(("http://", "https://")):
        return url

    try:
        # urlparse raises ValueError for malformed netlocs such as an
        # unterminated IPv6 literal ("http://[::1").
        path = urlparse(url).path
    except ValueError as e:
        logging.getLogger(__name__).exception("Error in get_file_key: %s", e)
        raise BusinessValidationException(str(e)) from e

    return unquote(path.lstrip('/'))