From 165385358fb05d878634d0718b51dabe3250403c Mon Sep 17 00:00:00 2001 From: deepvasoya Date: Thu, 5 Jun 2025 19:25:33 +0530 Subject: [PATCH] feat: clinic bank details api fix: relation betn stripe and user table --- apis/endpoints/clinicDoctor.py | 3 +- apis/endpoints/dashboard.py | 2 +- apis/endpoints/stripe.py | 4 + .../e50edac1c8f0_updated_stripeuser.py | 38 ++ models/Clinics.py | 5 +- models/StripeUsers.py | 6 +- models/Users.py | 3 +- schemas/BaseSchemas.py | 5 + schemas/ResponseSchemas.py | 39 ++- services/clinicDoctorsServices.py | 19 +- services/dashboardService.py | 20 +- services/stripeServices.py | 329 +++++++++++++----- services/userServices.py | 57 ++- 13 files changed, 381 insertions(+), 149 deletions(-) create mode 100644 migrations/versions/e50edac1c8f0_updated_stripeuser.py diff --git a/apis/endpoints/clinicDoctor.py b/apis/endpoints/clinicDoctor.py index fd3c1c6..d27a7cb 100644 --- a/apis/endpoints/clinicDoctor.py +++ b/apis/endpoints/clinicDoctor.py @@ -10,6 +10,7 @@ router = APIRouter() @router.get("/") async def get_clinic_doctors( + req:Request, limit:int= DEFAULT_LIMIT, page:int = DEFAULT_PAGE, search:str = "", @@ -19,7 +20,7 @@ async def get_clinic_doctors( if page < 1: page = 1 offset = (page - 1) * limit - clinic_doctors = await ClinicDoctorsServices().get_clinic_doctors(limit, offset, search, sort_by, sort_order) + clinic_doctors = await ClinicDoctorsServices().get_clinic_doctors(req.state.user, limit, offset, search, sort_by, sort_order) return ApiResponse(data=clinic_doctors, message="Clinic doctors retrieved successfully") @router.post("/") diff --git a/apis/endpoints/dashboard.py b/apis/endpoints/dashboard.py index ef79c6e..751de7b 100644 --- a/apis/endpoints/dashboard.py +++ b/apis/endpoints/dashboard.py @@ -8,7 +8,7 @@ router = APIRouter() @router.get("/") async def get_clinic_doctor_status_count(req:Request): - counts = await DashboardService().get_dashboard_counts(isSuperAdmin=req.state.user["userType"] == UserType.SUPER_ADMIN) + counts = await DashboardService().get_dashboard_counts(req.state.user) return ApiResponse(data=counts, message="Counts fetched successfully") @router.post("/signup-pricing-master") diff --git a/apis/endpoints/stripe.py b/apis/endpoints/stripe.py index d989d21..e7353c3 100644 --- a/apis/endpoints/stripe.py +++ b/apis/endpoints/stripe.py @@ -25,6 +25,10 @@ stripe_service = StripeServices() # customer_id="cus_SNn49FDltUcSLP" # ) +@router.get("/create-stripe-account-link", dependencies=[Depends(auth_required)]) +async def create_stripe_account_link(req:Request): + link = await stripe_service.create_stripe_account_link(req.state.user) + return ApiResponse(data=link, message="Stripe account link created successfully") @router.get("/get-invoice", dependencies=[Depends(auth_required)]) async def get_invoice(req:Request): diff --git a/migrations/versions/e50edac1c8f0_updated_stripeuser.py b/migrations/versions/e50edac1c8f0_updated_stripeuser.py new file mode 100644 index 0000000..b549198 --- /dev/null +++ b/migrations/versions/e50edac1c8f0_updated_stripeuser.py @@ -0,0 +1,38 @@ +"""updated_stripeuser + +Revision ID: e50edac1c8f0 +Revises: 8d19e726b997 +Create Date: 2025-06-05 18:22:38.502127 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e50edac1c8f0' +down_revision: Union[str, None] = '8d19e726b997' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('stripe_users', sa.Column('clinic_id', sa.Integer(), nullable=True)) + op.drop_constraint('stripe_users_user_id_key', 'stripe_users', type_='unique') + op.drop_constraint('stripe_users_user_id_fkey', 'stripe_users', type_='foreignkey') + op.create_foreign_key(None, 'stripe_users', 'clinics', ['clinic_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'stripe_users', type_='foreignkey') + op.create_foreign_key('stripe_users_user_id_fkey', 'stripe_users', 'users', ['user_id'], ['id']) + op.create_unique_constraint('stripe_users_user_id_key', 'stripe_users', ['user_id']) + op.drop_column('stripe_users', 'clinic_id') + # ### end Alembic commands ### diff --git a/models/Clinics.py b/models/Clinics.py index ad92ce2..f846661 100644 --- a/models/Clinics.py +++ b/models/Clinics.py @@ -45,4 +45,7 @@ class Clinics(Base, CustomBase): doctors = relationship("Doctors", back_populates="clinic") clinicDoctors = relationship("ClinicDoctors", back_populates="clinic") creator = relationship("Users", back_populates="created_clinics") - clinic_file_verifications = relationship("ClinicFileVerifications", back_populates="clinic") \ No newline at end of file + clinic_file_verifications = relationship("ClinicFileVerifications", back_populates="clinic") + + # Stripe relationships + stripe_user = relationship("StripeUsers", back_populates="clinic") \ No newline at end of file diff --git a/models/StripeUsers.py b/models/StripeUsers.py index c654181..2387b4a 100644 --- a/models/StripeUsers.py +++ b/models/StripeUsers.py @@ -6,8 +6,10 @@ from sqlalchemy.orm import relationship class StripeUsers(Base, CustomBase): __tablename__ = "stripe_users" id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, unique=True) + user_id = Column(Integer, nullable=True) + clinic_id = Column(Integer, ForeignKey('clinics.id'), nullable=True) customer_id = Column(String) account_id = Column(String) - user = relationship("Users", back_populates="stripe_user") \ No newline at end of file + + clinic = relationship("Clinics", back_populates="stripe_user") \ No newline at end of file diff --git a/models/Users.py b/models/Users.py index 081b942..db6e31a 100644 --- a/models/Users.py +++ b/models/Users.py @@ -27,5 +27,4 @@ class Users(Base, CustomBase): created_clinics = relationship("Clinics", back_populates="creator") clinic_file_verifications = relationship("ClinicFileVerifications", back_populates="last_changed_by_user") - # Stripe relationships - stripe_user = relationship("StripeUsers", back_populates="user") \ No newline at end of file + # No longer have Stripe relationships \ No newline at end of file diff --git a/schemas/BaseSchemas.py b/schemas/BaseSchemas.py index fa9869f..e11d388 100644 --- a/schemas/BaseSchemas.py +++ b/schemas/BaseSchemas.py @@ -139,3 +139,8 @@ class ClinicOffersBase(BaseModel): setup_fees_waived: bool special_offer_for_month: str + +class StripeUserBase(BaseModel): + account_id: str + customer_id: str + user_id: int \ No newline at end of file diff --git a/schemas/ResponseSchemas.py b/schemas/ResponseSchemas.py index a926a6c..b3cc3b0 100644 --- a/schemas/ResponseSchemas.py +++ b/schemas/ResponseSchemas.py @@ -23,7 +23,7 @@ class ClinicDoctorResponse(ClinicDoctorBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class SignupPricingMasterResponse(SignupPricingMasterBase): @@ -32,7 +32,7 @@ class SignupPricingMasterResponse(SignupPricingMasterBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class UserResponse(UserBase): @@ -53,7 +53,7 @@ class Doctor(DoctorBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class Patient(PatientBase): @@ -62,7 +62,7 @@ class Patient(PatientBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class AppointmentSchema(AppointmentBase): @@ -71,7 +71,7 @@ class AppointmentSchema(AppointmentBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class Calendar(CalendarBase): @@ -80,7 +80,7 @@ class Calendar(CalendarBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True # custom schema for response @@ -88,7 +88,7 @@ class CalendarTimeSchema(BaseModel): time: str class Config: - orm_mode = True + from_attributes = True class ClinicSchema(BaseModel): @@ -99,7 +99,7 @@ class ClinicSchema(BaseModel): email: str class Config: - orm_mode = True + from_attributes = True # Detailed response schemas with nested relationships @@ -133,7 +133,7 @@ class AppointmentDetailed(AppointmentSchema): address: str class Config: - orm_mode = True + from_attributes = True class Patient(BaseModel): id: int @@ -145,7 +145,7 @@ class AppointmentDetailed(AppointmentSchema): dob: str class Config: - orm_mode = True + from_attributes = True doctor: Doctor patient: Patient @@ -157,7 +157,7 @@ class CallTranscriptsResponse(CallTranscriptsBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class NotificationResponse(NotificationBase): @@ -166,7 +166,7 @@ class NotificationResponse(NotificationBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class MasterAppointmentTypeResponse(MasterAppointmentTypeBase): @@ -175,7 +175,7 @@ class MasterAppointmentTypeResponse(MasterAppointmentTypeBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True class ClinicDoctorResponse(ClinicDoctorBase): @@ -185,7 +185,6 @@ class ClinicDoctorResponse(ClinicDoctorBase): appointmentTypes: Optional[List[MasterAppointmentTypeResponse]] = [] class Config: - orm_mode = True from_attributes = True allow_population_by_field_name = True @@ -196,4 +195,14 @@ class ClinicOfferResponse(ClinicOffersBase): update_time: datetime class Config: - orm_mode = True + from_attributes = True + + + +class StripeUserReponse(StripeUserBase): + id: int + create_time: datetime + update_time: datetime + + class Config: + from_attributes = True \ No newline at end of file diff --git a/services/clinicDoctorsServices.py b/services/clinicDoctorsServices.py index 6d30e63..96af948 100644 --- a/services/clinicDoctorsServices.py +++ b/services/clinicDoctorsServices.py @@ -3,14 +3,13 @@ from schemas.CreateSchemas import ClinicDoctorCreate from schemas.UpdateSchemas import ClinicDoctorUpdate from schemas.ResponseSchemas import ClinicDoctorResponse, MasterAppointmentTypeResponse from database import get_db -from models import ClinicDoctors from sqlalchemy.orm import Session, joinedload, selectinload from services.clinicServices import ClinicServices from exceptions import ResourceNotFoundException from interface.common_response import CommonResponse from sqlalchemy import func, or_, cast, String from enums.enums import ClinicDoctorStatus, UserType -from models import MasterAppointmentTypes, AppointmentRelations +from models import MasterAppointmentTypes, AppointmentRelations, Users, ClinicDoctors from utils.constants import DEFAULT_ORDER, DEFAULT_ORDER_BY @@ -174,13 +173,14 @@ class ClinicDoctorsServices: finally: self.db.close() - async def get_doctor_status_count(self): + async def get_doctor_status_count(self, clinic_id:int): try: # Query to count doctors by status status_counts = ( self.db.query( ClinicDoctors.status, func.count(ClinicDoctors.id).label("count") ) + .filter(ClinicDoctors.clinic_id == clinic_id) .group_by(ClinicDoctors.status) .all() ) @@ -198,17 +198,18 @@ class ClinicDoctorsServices: finally: self.db.close() - async def get_clinic_doctors(self, limit: int, offset: int, search: str = "", sort_by: str = DEFAULT_ORDER, sort_order: str = DEFAULT_ORDER_BY): + async def get_clinic_doctors(self,user, limit: int, offset: int, search: str = "", sort_by: str = DEFAULT_ORDER, sort_order: str = DEFAULT_ORDER_BY): try: clinic_doctors_query = ( self.db.query(ClinicDoctors) + .filter(ClinicDoctors.clinic_id == user["created_clinics"][0]["id"]) .options( selectinload(ClinicDoctors.appointmentRelations) .selectinload(AppointmentRelations.masterAppointmentTypes) ) .order_by( - getattr(ClinicDoctors, sort_by).desc() - if sort_order == "desc" + getattr(ClinicDoctors, sort_by).desc() + if sort_order == "desc" else getattr(ClinicDoctors, sort_by).asc() ) ) @@ -230,7 +231,7 @@ class ClinicDoctorsServices: total = clinic_doctors_query.count() clinic_doctors = clinic_doctors_query.limit(limit).offset(offset).all() - + # Build response data manually to include appointment types response_data = [] for clinic_doctor in clinic_doctors: @@ -246,7 +247,7 @@ class ClinicDoctorsServices: update_time=relation.masterAppointmentTypes.update_time ) ) - + # Create the clinic doctor response clinic_doctor_data = ClinicDoctorResponse( id=clinic_doctor.id, @@ -258,7 +259,7 @@ class ClinicDoctorsServices: appointmentTypes=appointment_types ) response_data.append(clinic_doctor_data) - + response = CommonResponse( data=response_data, total=total, diff --git a/services/dashboardService.py b/services/dashboardService.py index 2532b7b..7b96375 100644 --- a/services/dashboardService.py +++ b/services/dashboardService.py @@ -8,6 +8,8 @@ from exceptions import UnauthorizedException from enums.enums import UserType from exceptions import ResourceNotFoundException from loguru import logger + +from models import Users class DashboardService: def __init__(self): self.db = next(get_db()) @@ -15,13 +17,17 @@ class DashboardService: self.clinicServices = ClinicServices() self.logger = logger - async def get_dashboard_counts(self, isSuperAdmin: bool): - if isSuperAdmin: - clinicCounts = await self.clinicServices.get_clinic_count() - return clinicCounts - else: - clinicDoctorsCount = await self.clinicDoctorsServices.get_doctor_status_count() - return clinicDoctorsCount + async def get_dashboard_counts(self, user): + try: + if user["userType"] == UserType.SUPER_ADMIN: + clinicCounts = await self.clinicServices.get_clinic_count() + return clinicCounts + else: + clinicDoctorsCount = await self.clinicDoctorsServices.get_doctor_status_count(user["created_clinics"][0]["id"]) + return clinicDoctorsCount + except Exception as e: + self.logger.error("Error getting dashboard counts: ", e) + raise e async def update_signup_pricing_master( self, user, pricing_data: SignupPricingMasterBase diff --git a/services/stripeServices.py b/services/stripeServices.py index 692c6a2..c9ca351 100644 --- a/services/stripeServices.py +++ b/services/stripeServices.py @@ -3,7 +3,6 @@ import os import dotenv - dotenv.load_dotenv() from models import ClinicOffers, StripeUsers @@ -11,12 +10,13 @@ from services.dashboardService import DashboardService from exceptions.validation_exception import ValidationException from exceptions.resource_not_found_exception import ResourceNotFoundException from exceptions.unauthorized_exception import UnauthorizedException -from models import Clinics,PaymentSessions, Subscriptions +from models import Clinics, PaymentSessions, Subscriptions import uuid from fastapi import Request from datetime import datetime from models import PaymentLogs from enums.enums import ClinicStatus, UserType +from schemas.ResponseSchemas import StripeUserReponse from database import get_db from sqlalchemy.orm import Session @@ -75,28 +75,110 @@ class StripeServices: self.logger.error(f"Error deleting account: {e}") raise + async def get_stripe_data(self, clinic_id: int): + try: + user = ( + self.db.query(StripeUsers) + .filter(StripeUsers.clinic_id == clinic_id) + .first() + ) + + if not user: + self.logger.error(f"User not found!") + raise ResourceNotFoundException("User not found!") + + return StripeUserReponse.model_validate(user).model_dump() + except Exception as e: + self.logger.error(f"Error retrieving account data: {e}") + raise + finally: + self.db.close() + + async def create_stripe_account_link(self, user): + try: + stripe_account = await self.get_stripe_data(user["created_clinics"][0]["id"]) + + if not stripe_account: + self.logger.error("Stripe account not found!") + raise ResourceNotFoundException("Stripe account not found!") + + # Pass the account_id as a string, not as a dictionary + data = await stripe.AccountLink.create_async( + account=stripe_account["account_id"], + refresh_url=self.redirect_url, + return_url=self.redirect_url, + type="account_onboarding", + ) + + return data.url + + except Exception as e: + self.logger.error(f"Error creating stripe account link: {e}") + raise + + async def check_account_capabilities(self, user): + try: + stripe_account = await self.get_stripe_data(user["created_clinics"][0]["id"]) + + if not stripe_account: + self.logger.error("Stripe account not found!") + raise ResourceNotFoundException("Stripe account not found!") + + data = await stripe.Account.retrieve_async(stripe_account["account_id"]) + + return { + "capabilities": data.capabilities, + "charges_enabled": data.charges_enabled, + "requirements": data.requirements.currently_due, + "error": data.requirements.errors, + } + + except Exception as e: + self.logger.error(f"Error checking stripe account capabilities: {e}") + raise + finally: + self.db.close() + async def get_invoice(self, user): try: if user["userType"] != UserType.CLINIC_ADMIN: - raise UnauthorizedException("User is not authorized to perform this action") + raise UnauthorizedException( + "User is not authorized to perform this action" + ) - clinic = self.db.query(Clinics).filter(Clinics.creator_id == user["id"]).first() + clinic = ( + self.db.query(Clinics).filter(Clinics.creator_id == user["id"]).first() + ) if not clinic: raise ResourceNotFoundException("Clinic not found!") - customer = self.db.query(StripeUsers).filter(StripeUsers.user_id == user["id"]).first() + customer = ( + self.db.query(StripeUsers) + .filter(StripeUsers.user_id == user["id"]) + .first() + ) if not customer: raise ResourceNotFoundException("Customer not found!") - subscription = self.db.query(Subscriptions).filter(Subscriptions.clinic_id == clinic.id, Subscriptions.customer_id == customer.customer_id, Subscriptions.status == "active").first() + subscription = ( + self.db.query(Subscriptions) + .filter( + Subscriptions.clinic_id == clinic.id, + Subscriptions.customer_id == customer.customer_id, + Subscriptions.status == "active", + ) + .first() + ) if not subscription: raise ResourceNotFoundException("Subscription not found!") - stripe_subscription = await stripe.Subscription.retrieve_async(subscription.subscription_id) + stripe_subscription = await stripe.Subscription.retrieve_async( + subscription.subscription_id + ) invoice = await stripe.Invoice.retrieve_async( stripe_subscription["latest_invoice"] @@ -114,36 +196,54 @@ class StripeServices: async def create_payment_session(self, user): try: if user["userType"] != UserType.CLINIC_ADMIN: - raise UnauthorizedException("User is not authorized to perform this action") - + raise UnauthorizedException( + "User is not authorized to perform this action" + ) + clinic = user["created_clinics"][0] if clinic["status"] != ClinicStatus.PAYMENT_DUE: raise ValidationException("Clinic is not due for payment") - customer = self.db.query(StripeUsers).filter(StripeUsers.user_id == user['id']).first() + customer = ( + self.db.query(StripeUsers) + .filter(StripeUsers.user_id == user["id"]) + .first() + ) if not customer: raise ResourceNotFoundException("Customer not found") - clinic_offers = self.db.query(ClinicOffers).filter(ClinicOffers.clinic_email == clinic["email"]).first() + clinic_offers = ( + self.db.query(ClinicOffers) + .filter(ClinicOffers.clinic_email == clinic["email"]) + .first() + ) - signup_pricing= await self.dashboard_service.get_signup_pricing_master() + signup_pricing = await self.dashboard_service.get_signup_pricing_master() fees_to_be = { "setup_fees": signup_pricing.setup_fees, "subscription_fees": signup_pricing.subscription_fees, "per_call_charges": signup_pricing.per_call_charges, - "total": signup_pricing.setup_fees + signup_pricing.subscription_fees + signup_pricing.per_call_charges + "total": signup_pricing.setup_fees + + signup_pricing.subscription_fees + + signup_pricing.per_call_charges, } if clinic_offers: fees_to_be["setup_fees"] = clinic_offers.setup_fees fees_to_be["per_call_charges"] = clinic_offers.per_call_charges - fees_to_be["total"] = clinic_offers.setup_fees + fees_to_be["subscription_fees"] + clinic_offers.per_call_charges - - payment_link = await self.create_subscription_checkout(fees_to_be, clinic["id"], customer.account_id, customer.customer_id) - + fees_to_be["total"] = ( + clinic_offers.setup_fees + + fees_to_be["subscription_fees"] + + clinic_offers.per_call_charges + ) + + payment_link = await self.create_subscription_checkout( + fees_to_be, clinic["id"], customer.account_id, customer.customer_id + ) + return payment_link.url except Exception as e: @@ -193,67 +293,79 @@ class StripeServices: self.logger.error(f"Error creating setup intent: {e}") raise - async def create_subscription_checkout(self, fees_to_be: dict, clinic_id: int, account_id: str, customer_id: str): + async def create_subscription_checkout( + self, fees_to_be: dict, clinic_id: int, account_id: str, customer_id: str + ): try: unique_id = str(uuid.uuid4()) unique_clinic_id = f"clinic_{clinic_id}_{unique_id}" - line_items = [{ - 'price_data': { - 'currency': 'aud', - 'product_data': { - 'name': 'Monthly Subscription', + line_items = [ + { + "price_data": { + "currency": "aud", + "product_data": { + "name": "Monthly Subscription", + }, + "unit_amount": int( + fees_to_be["subscription_fees"] * 100 + ), # Convert to cents + "recurring": { + "interval": "year", + }, }, - 'unit_amount': int(fees_to_be["subscription_fees"] * 100), # Convert to cents - 'recurring': { - 'interval': 'year', - }, - }, - 'quantity': 1, - }] + "quantity": 1, + } + ] - line_items.append({ - 'price_data': { - 'currency': 'aud', - 'product_data': { - 'name': 'Per Call', + line_items.append( + { + "price_data": { + "currency": "aud", + "product_data": { + "name": "Per Call", + }, + "unit_amount": int( + fees_to_be["per_call_charges"] * 100 + ), # Convert to cents }, - 'unit_amount': int(fees_to_be["per_call_charges"] * 100), # Convert to cents - }, - 'quantity': 1, - }) + "quantity": 1, + } + ) - line_items.append({ - 'price_data': { - 'currency': 'aud', - 'product_data': { - 'name': 'Setup Fee', + line_items.append( + { + "price_data": { + "currency": "aud", + "product_data": { + "name": "Setup Fee", + }, + "unit_amount": int( + fees_to_be["setup_fees"] * 100 + ), # Convert to cents }, - 'unit_amount': int(fees_to_be["setup_fees"] * 100), # Convert to cents - }, - 'quantity': 1, - }) + "quantity": 1, + } + ) metadata = { "clinic_id": clinic_id, "unique_clinic_id": unique_clinic_id, "account_id": account_id, "customer_id": customer_id, - "fees_to_be": json.dumps(fees_to_be) + "fees_to_be": json.dumps(fees_to_be), } session_data = { - 'customer': customer_id, - "payment_method_types": ["card","au_becs_debit"], - 'mode': 'subscription', - 'line_items': line_items, - 'success_url': f"{self.redirect_url}auth/waiting", - 'cancel_url': f"{self.redirect_url}auth/waiting", - 'metadata': metadata, - 'subscription_data': { - 'metadata': metadata - } + "customer": customer_id, + "payment_method_types": ["card", "au_becs_debit"], + "mode": "subscription", + "line_items": line_items, + "success_url": f"{self.redirect_url}auth/waiting", + "cancel_url": f"{self.redirect_url}auth/waiting", + "metadata": metadata, + "subscription_data": {"metadata": metadata}, } session = await stripe.checkout.Session.create_async(**session_data) @@ -261,18 +373,20 @@ class StripeServices: payment_log = PaymentLogs( customer_id=customer_id, account_id=account_id, - amount=Decimal(str(fees_to_be["total"])), # Keep as Decimal for database storage + amount=Decimal( + str(fees_to_be["total"]) + ), # Keep as Decimal for database storage clinic_id=clinic_id, unique_clinic_id=unique_clinic_id, payment_status="pending", - metadata_logs=json.dumps(metadata) + metadata_logs=json.dumps(metadata), ) new_payment_session = PaymentSessions( session_id=session.id, customer_id=customer_id, clinic_id=clinic_id, - status="pending" + status="pending", ) self.db.add(payment_log) @@ -296,12 +410,16 @@ class StripeServices: if event["type"] == "customer.subscription.deleted": self.logger.info("customer subscription ended") - subscription_id = event["data"]["object"]["items"]["data"][0]["subscription"] - + subscription_id = event["data"]["object"]["items"]["data"][0][ + "subscription" + ] + await self._subscription_expired(subscription_id) if event["type"] == "checkout.session.completed": - unique_clinic_id = event["data"]["object"]["metadata"]["unique_clinic_id"] + unique_clinic_id = event["data"]["object"]["metadata"][ + "unique_clinic_id" + ] clinic_id = event["data"]["object"]["metadata"]["clinic_id"] customer_id = event["data"]["object"]["metadata"]["customer_id"] account_id = event["data"]["object"]["metadata"]["account_id"] @@ -310,15 +428,24 @@ class StripeServices: session_id = event["data"]["object"]["id"] subscription_id = event["data"]["object"]["subscription"] - await self._update_payment_log(unique_clinic_id, clinic_id, customer_id, account_id, total, metadata) + await self._update_payment_log( + unique_clinic_id, + clinic_id, + customer_id, + account_id, + total, + metadata, + ) - await self._create_subscription_entry({ - "clinic_id": clinic_id, - "customer_id": customer_id, - "account_id": account_id, - "session_id": session_id, - "subscription_id": subscription_id, - }) + await self._create_subscription_entry( + { + "clinic_id": clinic_id, + "customer_id": customer_id, + "account_id": account_id, + "session_id": session_id, + "subscription_id": subscription_id, + } + ) # TODO: handle subscription period end return "OK" @@ -329,9 +456,19 @@ class StripeServices: finally: self.db.close() - async def _update_payment_log(self, unique_clinic_id:str, clinic_id:int, customer_id:str, account_id:str, total:float, metadata:any): + async def _update_payment_log( + self, + unique_clinic_id: str, + clinic_id: int, + customer_id: str, + account_id: str, + total: float, + metadata: any, + ): try: - self.db.query(PaymentSessions).filter(PaymentSessions.clinic_id == clinic_id).delete() + self.db.query(PaymentSessions).filter( + PaymentSessions.clinic_id == clinic_id + ).delete() payment_log = PaymentLogs( customer_id=customer_id, @@ -340,12 +477,12 @@ class StripeServices: clinic_id=clinic_id, unique_clinic_id=unique_clinic_id, payment_status="paid", - metadata_logs=json.dumps(metadata.to_dict()) + metadata_logs=json.dumps(metadata.to_dict()), ) self.db.add(payment_log) - + clinic = self.db.query(Clinics).filter(Clinics.id == clinic_id).first() - + if clinic: clinic.status = ClinicStatus.UNDER_REVIEW self.db.add(clinic) @@ -356,12 +493,12 @@ class StripeServices: self.db.commit() self.db.close() - async def _create_subscription_entry(self,data:dict): + async def _create_subscription_entry(self, data: dict): try: subscription = stripe.Subscription.retrieve(data["subscription_id"]) - metadata_dict = json.loads(subscription.metadata) + metadata_dict = subscription.metadata fees_to_be = json.loads(metadata_dict["fees_to_be"]) new_subscription = Subscriptions( @@ -374,9 +511,13 @@ class StripeServices: per_call_charge=fees_to_be["per_call_charges"], subscription_id=data["subscription_id"], status=subscription.status, - current_period_start=subscription["items"]["data"][0]["current_period_start"], - current_period_end=subscription["items"]["data"][0]["current_period_end"], - metadata_logs=json.dumps(subscription.metadata) + current_period_start=subscription["items"]["data"][0][ + "current_period_start" + ], + current_period_end=subscription["items"]["data"][0][ + "current_period_end" + ], + metadata_logs=json.dumps(subscription.metadata), ) self.db.add(new_subscription) @@ -384,7 +525,7 @@ class StripeServices: session_id=data["session_id"], customer_id=data["customer_id"], clinic_id=data["clinic_id"], - status="paid" + status="paid", ) self.db.add(payment_session) return @@ -393,26 +534,30 @@ class StripeServices: finally: self.db.commit() self.db.close() - - async def _subscription_expired(self,subscription_id): + + async def _subscription_expired(self, subscription_id): try: subscription = stripe.Subscription.retrieve(subscription_id) - - db_subscription = self.db.query(Subscriptions).filter(Subscriptions.subscription_id == subscription_id).first() - + + db_subscription = ( + self.db.query(Subscriptions) + .filter(Subscriptions.subscription_id == subscription_id) + .first() + ) + if not db_subscription: self.logger.error("Subscription not found!") raise Exception("Subscription not found!") - + db_subscription.status = subscription.status self.db.add(db_subscription) # TODO: update clinic status # TODO: send email to user - + return except Exception as e: self.logger.error(f"Error ending subscription: {e}") finally: self.db.commit() - self.db.close() \ No newline at end of file + self.db.close() diff --git a/services/userServices.py b/services/userServices.py index e86d9ca..6378eae 100644 --- a/services/userServices.py +++ b/services/userServices.py @@ -1,5 +1,6 @@ import asyncio from loguru import logger +from sqlalchemy import or_ from sqlalchemy.orm import Session from database import get_db @@ -62,23 +63,6 @@ class UserServices: self.db.add(new_user) self.db.flush() # Flush to get the user ID without committing - stripe_customer, stripe_account = await asyncio.gather( - self.stripe_service.create_customer( - new_user.id, user.email, user.username - ), - self.stripe_service.create_account( - new_user.id, user.email, user.username, user.mobile - ), - ) - - # Create stripe user - stripe_user = StripeUsers( - user_id=new_user.id, - customer_id=stripe_customer.id, - account_id=stripe_account.id, - ) - self.db.add(stripe_user) - # Get clinic data clinic = user_data.clinic @@ -90,18 +74,35 @@ class UserServices: if char.isalnum() or char == "-" or char == "_" ) existing_clinic = ( - self.db.query(Clinics).filter(Clinics.domain == domain).first() + self.db.query(Clinics).filter( + or_(Clinics.domain == domain, + Clinics.email == clinic.email, + Clinics.phone == clinic.phone, + Clinics.emergency_phone == clinic.emergency_phone, + Clinics.abn_number == clinic.abn_number, + ) + ).first() ) if existing_clinic: # This will trigger rollback in the exception handler - raise ValidationException("Clinic with same domain already exists") + if existing_clinic.domain == domain: + raise ValidationException("Clinic with same name already exists") + if existing_clinic.email == clinic.email: + raise ValidationException("Clinic with same email already exists") + if existing_clinic.phone == clinic.phone: + raise ValidationException("Clinic with same phone already exists") + if existing_clinic.emergency_phone == clinic.emergency_phone: + raise ValidationException("Clinic with same emergency phone already exists") + if existing_clinic.abn_number == clinic.abn_number: + raise ValidationException("Clinic with same ABN already exists") # Create clinic instance new_clinic = Clinics( name=clinic.name, address=clinic.address, phone=clinic.phone, + emergency_phone=clinic.emergency_phone, email=clinic.email, integration=clinic.integration, pms_id=clinic.pms_id, @@ -151,6 +152,24 @@ class UserServices: clinic.email ) + + stripe_customer, stripe_account = await asyncio.gather( + self.stripe_service.create_customer( + new_user.id, user.email, user.username + ), + self.stripe_service.create_account( + new_user.id, user.email, user.username, user.mobile + ), + ) + + # Create stripe user + stripe_user = StripeUsers( + user_id=new_user.id, + customer_id=stripe_customer.id, + account_id=stripe_account.id, + ) + self.db.add(stripe_user) + signup_pricing = await self.dashboard_service.get_signup_pricing_master() fees_to_be = {