refactor: minor changes
This commit is contained in:
parent
7558b5e2fe
commit
f545a5b75b
|
|
@ -23,16 +23,12 @@ async def get_users(limit:int = DEFAULT_LIMIT, page:int = DEFAULT_PAGE, search:s
|
|||
|
||||
@router.get("/me")
|
||||
async def get_user(request: Request):
|
||||
try:
|
||||
user_id = request.state.user["id"]
|
||||
user = await UserServices().get_user(user_id)
|
||||
return ApiResponse(
|
||||
data=user,
|
||||
message="User fetched successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user: {str(e)}")
|
||||
raise e
|
||||
user_id = request.state.user["id"]
|
||||
user = await UserServices().get_user(user_id)
|
||||
return ApiResponse(
|
||||
data=user,
|
||||
message="User fetched successfully"
|
||||
)
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(request: Request, user_id: int):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class Clinic(ClinicBase):
|
|||
status: ClinicStatus
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ClinicDoctorResponse(ClinicDoctorBase):
|
||||
|
|
@ -43,9 +43,8 @@ class UserResponse(UserBase):
|
|||
created_clinics: Optional[List[Clinic]] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
from_attributes = True
|
||||
allow_population_by_field_name = True
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class Doctor(DoctorBase):
|
||||
|
|
|
|||
|
|
@ -38,21 +38,28 @@ class AuthService:
|
|||
self.logger = logger
|
||||
|
||||
async def login(self, data: AuthBase) -> str:
|
||||
try:
|
||||
# get user
|
||||
user = await self.user_service.get_user_by_email(data.email)
|
||||
|
||||
# get user
|
||||
user = await self.user_service.get_user_by_email(data.email)
|
||||
# verify password
|
||||
if not verify_password(data.password, user.password):
|
||||
raise UnauthorizedException("Invalid credentials")
|
||||
|
||||
# verify password
|
||||
if not verify_password(data.password, user.password):
|
||||
raise UnauthorizedException("Invalid credentials")
|
||||
user_dict = user.model_dump(
|
||||
exclude={"password": True, "created_clinics": True},
|
||||
exclude_none=True,
|
||||
mode="json"
|
||||
)
|
||||
|
||||
# remove password from user dict
|
||||
user_dict = user.__dict__.copy()
|
||||
user_dict.pop("password", None)
|
||||
|
||||
# create token
|
||||
token = create_jwt_token(user_dict)
|
||||
return token
|
||||
# create token
|
||||
token = create_jwt_token(user_dict)
|
||||
return token
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error logging in: {e}")
|
||||
raise e
|
||||
finally:
|
||||
self.db.close()
|
||||
|
||||
async def register(self, user_data: UserCreate, background_tasks=None):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,22 +5,25 @@ from enum import Enum
|
|||
from utils.constants import JWT_SECRET, JWT_ALGORITHM, JWT_EXPIRE_MINUTES
|
||||
|
||||
def create_jwt_token(data: dict):
|
||||
# Create a copy of the data and handle Enum and datetime serialization
|
||||
to_encode = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, Enum):
|
||||
to_encode[key] = value.value # Convert Enum to its string value
|
||||
elif isinstance(value, datetime):
|
||||
to_encode[key] = value.isoformat() # Convert datetime to ISO format string
|
||||
else:
|
||||
to_encode[key] = value
|
||||
try:
|
||||
# Create a copy of the data and handle Enum and datetime serialization
|
||||
to_encode = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, Enum):
|
||||
to_encode[key] = value.value # Convert Enum to its string value
|
||||
elif isinstance(value, datetime):
|
||||
to_encode[key] = value.isoformat() # Convert datetime to ISO format string
|
||||
else:
|
||||
to_encode[key] = value
|
||||
|
||||
# Safely evaluate the JWT_EXPIRE_MINUTES expression
|
||||
minutes = eval(JWT_EXPIRE_MINUTES) if isinstance(JWT_EXPIRE_MINUTES, str) else JWT_EXPIRE_MINUTES
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=minutes)
|
||||
to_encode.update({"exp": expire.timestamp()}) # Use timestamp for expiration
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
# Safely evaluate the JWT_EXPIRE_MINUTES expression
|
||||
minutes = eval(JWT_EXPIRE_MINUTES) if isinstance(JWT_EXPIRE_MINUTES, str) else JWT_EXPIRE_MINUTES
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=minutes)
|
||||
to_encode.update({"exp": expire.timestamp()}) # Use timestamp for expiration
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def verify_jwt_token(token: str):
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ class StripeServices:
|
|||
|
||||
session_data = {
|
||||
'customer': customer_id,
|
||||
'payment_method_types': ['card'],
|
||||
"payment_method_types": ["card","au_becs_debit"],
|
||||
'mode': 'subscription',
|
||||
'line_items': line_items,
|
||||
'success_url': f"{self.redirect_url}auth/waiting",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
|
|||
from database import get_db
|
||||
from models.Users import Users
|
||||
from exceptions.validation_exception import ValidationException
|
||||
from schemas.ResponseSchemas import UserResponse
|
||||
from schemas.ResponseSchemas import Clinic, UserResponse
|
||||
from models import Clinics
|
||||
from enums.enums import ClinicStatus, UserType
|
||||
from schemas.UpdateSchemas import UserUpdate
|
||||
|
|
@ -20,6 +20,8 @@ from sqlalchemy.orm import joinedload
|
|||
from services.emailService import EmailService
|
||||
from services.clinicServices import ClinicServices
|
||||
from services.dashboardService import DashboardService
|
||||
|
||||
|
||||
class UserServices:
|
||||
def __init__(self):
|
||||
self.db: Session = next(get_db())
|
||||
|
|
@ -27,6 +29,7 @@ class UserServices:
|
|||
self.stripe_service = StripeServices()
|
||||
self.clinic_service = ClinicServices()
|
||||
self.dashboard_service = DashboardService()
|
||||
self.logger = logger
|
||||
|
||||
async def create_user(self, user_data: UserCreate, background_tasks=None):
|
||||
|
||||
|
|
@ -38,15 +41,11 @@ class UserServices:
|
|||
user = user_data.user
|
||||
# Check if user with same username or email exists
|
||||
existing_user = (
|
||||
self.db.query(Users)
|
||||
.filter(Users.email == user.email.lower())
|
||||
.first()
|
||||
self.db.query(Users).filter(Users.email == user.email.lower()).first()
|
||||
)
|
||||
|
||||
if existing_user:
|
||||
raise ValidationException(
|
||||
"User with same email already exists"
|
||||
)
|
||||
raise ValidationException("User with same email already exists")
|
||||
|
||||
# Create a new user instance
|
||||
new_user = Users(
|
||||
|
|
@ -55,7 +54,7 @@ class UserServices:
|
|||
password=hash_password(user.password),
|
||||
clinicRole=user.clinicRole,
|
||||
userType=user.userType,
|
||||
mobile=user.mobile
|
||||
mobile=user.mobile,
|
||||
)
|
||||
|
||||
# Add user to database but don't commit yet
|
||||
|
|
@ -63,16 +62,20 @@ class UserServices:
|
|||
self.db.flush() # Flush to get the user ID without committing
|
||||
|
||||
# Create stripe customer
|
||||
stripe_customer = await self.stripe_service.create_customer(new_user.id, user.email, user.username)
|
||||
stripe_customer = await self.stripe_service.create_customer(
|
||||
new_user.id, user.email, user.username
|
||||
)
|
||||
|
||||
# Create stripe account
|
||||
stripe_account = await self.stripe_service.create_account(new_user.id, user.email, user.username, user.mobile)
|
||||
stripe_account = await self.stripe_service.create_account(
|
||||
new_user.id, user.email, user.username, user.mobile
|
||||
)
|
||||
|
||||
# Create stripe user
|
||||
stripe_user = StripeUsers(
|
||||
user_id=new_user.id,
|
||||
customer_id=stripe_customer.id,
|
||||
account_id=stripe_account.id
|
||||
account_id=stripe_account.id,
|
||||
)
|
||||
self.db.add(stripe_user)
|
||||
|
||||
|
|
@ -81,8 +84,14 @@ class UserServices:
|
|||
|
||||
# cross verify domain, in db
|
||||
# Convert to lowercase and keep only alphanumeric characters, hyphens, and underscores
|
||||
domain = ''.join(char for char in clinic.name.lower() if char.isalnum() or char == '-' or char == '_')
|
||||
existing_clinic = self.db.query(Clinics).filter(Clinics.domain == domain).first()
|
||||
domain = "".join(
|
||||
char
|
||||
for char in clinic.name.lower()
|
||||
if char.isalnum() or char == "-" or char == "_"
|
||||
)
|
||||
existing_clinic = (
|
||||
self.db.query(Clinics).filter(Clinics.domain == domain).first()
|
||||
)
|
||||
|
||||
if existing_clinic:
|
||||
# This will trigger rollback in the exception handler
|
||||
|
|
@ -114,7 +123,7 @@ class UserServices:
|
|||
voice_model_gender=clinic.voice_model_gender,
|
||||
scenarios=clinic.scenarios,
|
||||
general_info=clinic.general_info,
|
||||
status=ClinicStatus.PAYMENT_DUE, #TODO: change this to PAYMENT_DUE
|
||||
status=ClinicStatus.PAYMENT_DUE, # TODO: change this to PAYMENT_DUE
|
||||
domain=domain,
|
||||
creator_id=new_user.id, # Set the creator_id to link the clinic to the user who created it
|
||||
)
|
||||
|
|
@ -128,7 +137,7 @@ class UserServices:
|
|||
clinic_id=new_clinic.id,
|
||||
abn_doc_is_verified=None,
|
||||
contract_doc_is_verified=None,
|
||||
last_changed_by=new_user.id
|
||||
last_changed_by=new_user.id,
|
||||
)
|
||||
|
||||
# Add clinic files to database
|
||||
|
|
@ -138,7 +147,9 @@ class UserServices:
|
|||
if background_tasks:
|
||||
background_tasks.add_task(self._send_emails_to_admins, clinic.email)
|
||||
|
||||
offer = await self.clinic_service.get_clinic_offer_by_clinic_email(clinic.email)
|
||||
offer = await self.clinic_service.get_clinic_offer_by_clinic_email(
|
||||
clinic.email
|
||||
)
|
||||
|
||||
signup_pricing = await self.dashboard_service.get_signup_pricing_master()
|
||||
|
||||
|
|
@ -146,15 +157,23 @@ class UserServices:
|
|||
"setup_fees": signup_pricing.setup_fees,
|
||||
"subscription_fees": signup_pricing.subscription_fees,
|
||||
"per_call_charges": signup_pricing.per_call_charges,
|
||||
"total": signup_pricing.setup_fees + signup_pricing.subscription_fees + signup_pricing.per_call_charges
|
||||
"total": signup_pricing.setup_fees
|
||||
+ signup_pricing.subscription_fees
|
||||
+ signup_pricing.per_call_charges,
|
||||
}
|
||||
|
||||
if offer:
|
||||
fees_to_be["setup_fees"] = offer.setup_fees
|
||||
fees_to_be["per_call_charges"] = offer.per_call_charges
|
||||
fees_to_be["total"] = offer.setup_fees + fees_to_be["subscription_fees"] + offer.per_call_charges
|
||||
fees_to_be["total"] = (
|
||||
offer.setup_fees
|
||||
+ fees_to_be["subscription_fees"]
|
||||
+ offer.per_call_charges
|
||||
)
|
||||
|
||||
payment_link = await self.stripe_service.create_subscription_checkout(fees_to_be, new_clinic.id, stripe_account.id,stripe_customer.id)
|
||||
payment_link = await self.stripe_service.create_subscription_checkout(
|
||||
fees_to_be, new_clinic.id, stripe_account.id, stripe_customer.id
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
|
|
@ -166,7 +185,7 @@ class UserServices:
|
|||
"clinicRole": new_user.clinicRole,
|
||||
"userType": new_user.userType,
|
||||
"mobile": new_user.mobile,
|
||||
"clinicId": new_clinic.id
|
||||
"clinicId": new_clinic.id,
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -190,33 +209,22 @@ class UserServices:
|
|||
finally:
|
||||
self.db.close()
|
||||
|
||||
|
||||
async def get_user(self, user_id) -> UserResponse:
|
||||
try:
|
||||
# Query the user by ID and explicitly load the created clinics relationship
|
||||
user = self.db.query(Users).options(joinedload(Users.created_clinics)).filter(Users.id == user_id).first()
|
||||
user = (
|
||||
self.db.query(Users)
|
||||
.options(joinedload(Users.created_clinics))
|
||||
.filter(Users.id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user:
|
||||
logger.error("User not found")
|
||||
self.logger.error("User not found")
|
||||
raise ResourceNotFoundException("User not found")
|
||||
|
||||
# First convert the user to a dictionary
|
||||
user_dict = {}
|
||||
for column in user.__table__.columns:
|
||||
user_dict[column.name] = getattr(user, column.name)
|
||||
|
||||
# Convert created clinics to dictionaries
|
||||
if user.created_clinics:
|
||||
clinics_list = []
|
||||
for clinic in user.created_clinics:
|
||||
clinic_dict = {}
|
||||
for column in clinic.__table__.columns:
|
||||
clinic_dict[column.name] = getattr(clinic, column.name)
|
||||
clinics_list.append(clinic_dict)
|
||||
user_dict['created_clinics'] = clinics_list
|
||||
|
||||
# Create the user response
|
||||
user_response = UserResponse.model_validate(user_dict)
|
||||
user_response = UserResponse.model_validate(user)
|
||||
|
||||
# Return the response as a dictionary
|
||||
return user_response.model_dump()
|
||||
|
|
@ -225,7 +233,7 @@ class UserServices:
|
|||
finally:
|
||||
self.db.close()
|
||||
|
||||
async def get_users(self, limit:int, offset:int, search:str):
|
||||
async def get_users(self, limit: int, offset: int, search: str):
|
||||
try:
|
||||
query = self.db.query(Users)
|
||||
if search:
|
||||
|
|
@ -233,16 +241,19 @@ class UserServices:
|
|||
or_(
|
||||
Users.username.contains(search),
|
||||
Users.email.contains(search),
|
||||
Users.clinicRole.contains(search),
|
||||
Users.userType.contains(search)
|
||||
Users.clinicRole.contains(search),
|
||||
Users.userType.contains(search),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
users = query.limit(limit).offset(offset).all()
|
||||
|
||||
total = self.db.query(Users).count()
|
||||
|
||||
response = CommonResponse(data=[UserResponse(**user.__dict__.copy()) for user in users], total=total)
|
||||
response = CommonResponse(
|
||||
data=[UserResponse(**user.__dict__.copy()) for user in users],
|
||||
total=total,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
|
|
@ -252,15 +263,17 @@ class UserServices:
|
|||
|
||||
async def get_user_by_email(self, email: str) -> UserResponse:
|
||||
try:
|
||||
user = self.db.query(Users).filter(Users.email == email.lower()).first()
|
||||
user = (
|
||||
self.db.query(Users)
|
||||
.filter(Users.email == email.lower())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user:
|
||||
logger.error("User not found")
|
||||
self.logger.error("User not found")
|
||||
raise ResourceNotFoundException("User not found")
|
||||
|
||||
user_dict = user.__dict__.copy()
|
||||
|
||||
user_response = UserResponse(**user_dict)
|
||||
user_response = UserResponse.model_validate(user)
|
||||
|
||||
return user_response
|
||||
except Exception as e:
|
||||
|
|
@ -268,24 +281,28 @@ class UserServices:
|
|||
finally:
|
||||
self.db.close()
|
||||
|
||||
async def update_user(self, admin_id:int|None, user_id: int, user_data: UserUpdate) -> UserResponse:
|
||||
async def update_user(
|
||||
self, admin_id: int | None, user_id: int, user_data: UserUpdate
|
||||
) -> UserResponse:
|
||||
try:
|
||||
# Check admin authorization if admin_id is provided
|
||||
if admin_id:
|
||||
admin = self.db.query(Users).filter(Users.id == admin_id).first()
|
||||
if not admin:
|
||||
logger.error("Admin not found")
|
||||
self.logger.error("Admin not found")
|
||||
raise ResourceNotFoundException("Admin not found")
|
||||
|
||||
# Only check admin type if admin_id was provided
|
||||
if admin.userType != UserType.SUPER_ADMIN:
|
||||
logger.error("User is not authorized to perform this action")
|
||||
raise UnauthorizedException("User is not authorized to perform this action")
|
||||
self.logger.error("User is not authorized to perform this action")
|
||||
raise UnauthorizedException(
|
||||
"User is not authorized to perform this action"
|
||||
)
|
||||
|
||||
# Find the user to update
|
||||
user = self.db.query(Users).filter(Users.id == user_id).first()
|
||||
if not user:
|
||||
logger.error("User not found")
|
||||
self.logger.error("User not found")
|
||||
raise ResourceNotFoundException("User not found")
|
||||
|
||||
# Update only the fields that were provided
|
||||
|
|
@ -309,7 +326,7 @@ class UserServices:
|
|||
user = self.db.query(Users).filter(Users.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
logger.error("User not found")
|
||||
self.logger.error("User not found")
|
||||
raise ResourceNotFoundException("User not found")
|
||||
|
||||
# Use the soft_delete method from CustomBase
|
||||
|
|
@ -323,7 +340,11 @@ class UserServices:
|
|||
|
||||
async def get_super_admins(self):
|
||||
try:
|
||||
return self.db.query(Users).filter(Users.userType == UserType.SUPER_ADMIN).all()
|
||||
return (
|
||||
self.db.query(Users)
|
||||
.filter(Users.userType == UserType.SUPER_ADMIN)
|
||||
.all()
|
||||
)
|
||||
except Exception as e:
|
||||
DBExceptionHandler.handle_exception(e, context="getting super admins")
|
||||
finally:
|
||||
|
|
@ -335,12 +356,11 @@ class UserServices:
|
|||
admins = await self.get_super_admins()
|
||||
for admin in admins:
|
||||
self.email_service.send_new_clinic_email(
|
||||
to_address=admin.email,
|
||||
clinic_name=clinic_name
|
||||
to_address=admin.email, clinic_name=clinic_name
|
||||
)
|
||||
except Exception as e:
|
||||
# Log the error but don't interrupt the main flow
|
||||
logger.error(f"Error sending admin emails: {str(e)}")
|
||||
self.logger.error(f"Error sending admin emails: {str(e)}")
|
||||
finally:
|
||||
self.db.close()
|
||||
|
||||
|
|
@ -349,7 +369,7 @@ class UserServices:
|
|||
user = self.db.query(Users).filter(Users.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
logger.error("User not found")
|
||||
self.logger.error("User not found")
|
||||
raise ResourceNotFoundException("User not found")
|
||||
|
||||
return self.stripe_service.create_payment_link(user_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue