From e6fbae493bc35e38bcd7f85e5a0a46dce74ce1ee Mon Sep 17 00:00:00 2001 From: deepvasoya Date: Wed, 28 May 2025 15:28:36 +0530 Subject: [PATCH] feat: stripe improvments --- apis/endpoints/stripe.py | 42 +++++++++++------- models/PaymentLogs.py | 4 +- models/PaymentSessions.py | 12 +++++ models/__init__.py | 4 +- services/clinicServices.py | 7 ++- services/stripeServices.py | 91 ++++++++++++++++++++++++++++++++------ services/userServices.py | 3 +- 7 files changed, 125 insertions(+), 38 deletions(-) create mode 100644 models/PaymentSessions.py diff --git a/apis/endpoints/stripe.py b/apis/endpoints/stripe.py index 26b7547..03b6eeb 100644 --- a/apis/endpoints/stripe.py +++ b/apis/endpoints/stripe.py @@ -1,27 +1,35 @@ -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends, Request from services.stripeServices import StripeServices +from middleware.auth_dependency import auth_required +from schemas.ApiResponse import ApiResponse router = APIRouter() stripe_service = StripeServices() -@router.post("/create-checkout-session") -async def create_checkout_session(user_id: int): - return await stripe_service.create_checkout_session(1) +# @router.post("/create-checkout-session") +# async def create_checkout_session(user_id: int): +# return await stripe_service.create_checkout_session(1) -@router.post("/create-subscription-checkout") -async def create_subscription_checkout(): - return await stripe_service.create_subscription_checkout( - fees_to_be={ - "per_call_charges": 10, - "setup_fees": 100, - "subscription_fees": 100, - "total": 210 - }, - clinic_id=1, - account_id="acct_1RT1UFPTNqn2kWQ8", - customer_id="cus_SNn49FDltUcSLP" - ) +# @router.post("/create-subscription-checkout") +# async def create_subscription_checkout(): +# return await stripe_service.create_subscription_checkout( +# fees_to_be={ +# "per_call_charges": 10, +# "setup_fees": 100, +# "subscription_fees": 100, +# "total": 210 +# }, +# clinic_id=1, +# account_id="acct_1RT1UFPTNqn2kWQ8", +# customer_id="cus_SNn49FDltUcSLP" +# ) + + +@router.post("/create-payment-session", dependencies=[Depends(auth_required)]) +async def create_payment_session(req:Request): + session = await stripe_service.create_payment_session(req.state.user) + return ApiResponse(data=session, message="Payment session created successfully") @router.post("/webhook") async def stripe_webhook(request: Request): diff --git a/models/PaymentLogs.py b/models/PaymentLogs.py index f0b1102..f242d09 100644 --- a/models/PaymentLogs.py +++ b/models/PaymentLogs.py @@ -1,6 +1,6 @@ from database import Base from models.CustomBase import CustomBase -from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy import Column, Integer, String, ForeignKey, Numeric from sqlalchemy.orm import relationship class PaymentLogs(Base, CustomBase): @@ -8,7 +8,7 @@ class PaymentLogs(Base, CustomBase): id = Column(Integer, primary_key=True, index=True) customer_id = Column(String) account_id = Column(String) - amount = Column(Integer) + amount = Column(Numeric(10, 2)) clinic_id = Column(Integer) unique_clinic_id = Column(String) payment_status = Column(String) diff --git a/models/PaymentSessions.py b/models/PaymentSessions.py new file mode 100644 index 0000000..2b2daa9 --- /dev/null +++ b/models/PaymentSessions.py @@ -0,0 +1,12 @@ +from sqlalchemy import Column, Integer, String +from database import Base +from .CustomBase import CustomBase + +class PaymentSessions(Base, CustomBase): + __tablename__ = "payment_sessions" + id = Column(Integer, primary_key=True, index=True) + session_id = Column(String(255), unique=True, index=True) + customer_id = Column(String, nullable=False) + clinic_id = Column(Integer, nullable=False) + status = Column(String, nullable=False) + \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py index 3ed9484..279e551 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -18,6 +18,7 @@ from .ResetPasswordTokens import ResetPasswordTokens from .ClinicOffers import ClinicOffers from .StripeUsers import StripeUsers from .PaymentLogs import PaymentLogs +from .PaymentSessions import PaymentSessions __all__ = [ "Users", @@ -39,5 +40,6 @@ __all__ = [ "ResetPasswordTokens", "ClinicOffers", "StripeUsers", - "PaymentLogs" + "PaymentLogs", + "PaymentSessions" ] diff --git a/services/clinicServices.py b/services/clinicServices.py index 60e40e4..ab8238b 100644 --- a/services/clinicServices.py +++ b/services/clinicServices.py @@ -200,7 +200,8 @@ class ClinicServices: Clinics.status.in_([ ClinicStatus.ACTIVE, ClinicStatus.UNDER_REVIEW, - ClinicStatus.REJECTED + ClinicStatus.REJECTED, + ClinicStatus.PAYMENT_DUE ]) ).group_by(Clinics.status).all() @@ -209,13 +210,15 @@ class ClinicServices: "totalClinics": totalClinics, "totalActiveClinics": 0, "totalUnderReviewClinics": 0, - "totalRejectedClinics": 0 + "totalRejectedClinics": 0, + "totalPaymentDueClinics": 0 } # Map status values to their respective count keys status_to_key = { ClinicStatus.ACTIVE: "totalActiveClinics", ClinicStatus.UNDER_REVIEW: "totalUnderReviewClinics", + ClinicStatus.PAYMENT_DUE: "totalPaymentDueClinics", ClinicStatus.REJECTED: "totalRejectedClinics" } diff --git a/services/stripeServices.py b/services/stripeServices.py index 50a1c5f..ab96766 100644 --- a/services/stripeServices.py +++ b/services/stripeServices.py @@ -2,19 +2,27 @@ import json import os import dotenv + + dotenv.load_dotenv() -from models import Clinics +from models import ClinicOffers, StripeUsers +from services.dashboardService import DashboardService +from exceptions.validation_exception import ValidationException +from exceptions.resource_not_found_exception import ResourceNotFoundException +from exceptions.unauthorized_exception import UnauthorizedException +from models import Clinics,PaymentSessions import uuid from fastapi import Request from datetime import datetime from models import PaymentLogs -from enums.enums import ClinicStatus +from enums.enums import ClinicStatus, UserType from database import get_db from sqlalchemy.orm import Session import stripe from loguru import logger +from decimal import Decimal class StripeServices: @@ -22,6 +30,7 @@ class StripeServices: self.db: Session = next(get_db()) self.logger = logger self.webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET") + self.dashboard_service = DashboardService() async def create_customer(self, user_id: int, email: str, name: str): try: @@ -65,6 +74,50 @@ class StripeServices: self.logger.error(f"Error deleting account: {e}") raise + async def create_payment_session(self, user): + try: + if user["userType"] != UserType.CLINIC_ADMIN: + raise UnauthorizedException("User is not authorized to perform this action") + + clinic = user["created_clinics"][0] + + if clinic["status"] != ClinicStatus.PAYMENT_DUE: + raise ValidationException("Clinic is not due for payment") + + customer = self.db.query(StripeUsers).filter(StripeUsers.user_id == user['id']).first() + + if not customer: + raise ResourceNotFoundException("Customer not found") + + clinic_offers = self.db.query(ClinicOffers).filter(ClinicOffers.clinic_email == clinic["email"]).first() + + signup_pricing= await self.dashboard_service.get_signup_pricing_master() + + fees_to_be = { + "setup_fees": signup_pricing.setup_fees, + "subscription_fees": signup_pricing.subscription_fees, + "per_call_charges": signup_pricing.per_call_charges, + "total": signup_pricing.setup_fees + signup_pricing.subscription_fees + signup_pricing.per_call_charges + } + + if clinic_offers: + fees_to_be["setup_fees"] = clinic_offers.setup_fees + fees_to_be["per_call_charges"] = clinic_offers.per_call_charges + fees_to_be["total"] = clinic_offers.setup_fees + fees_to_be["subscription_fees"] + clinic_offers.per_call_charges + + # remove previouis payment session + self.db.query(PaymentSessions).filter(PaymentSessions.clinic_id == clinic["id"]).delete() + + payment_link = await self.create_subscription_checkout(fees_to_be, clinic["id"], customer.account_id, customer.customer_id) + + return payment_link.url + + except Exception as e: + self.logger.error(f"Error creating payment session: {e}") + raise + finally: + self.db.close() + async def create_checkout_session(self, user_id: int): try: checkout_session = stripe.checkout.Session.create( @@ -118,7 +171,7 @@ class StripeServices: 'product_data': { 'name': 'Monthly Subscription', }, - 'unit_amount': fees_to_be["subscription_fees"], + 'unit_amount': int(fees_to_be["subscription_fees"] * 100), # Convert to cents 'recurring': { 'interval': 'year', }, @@ -132,7 +185,7 @@ class StripeServices: 'product_data': { 'name': 'Per Call', }, - 'unit_amount': fees_to_be["per_call_charges"], + 'unit_amount': int(fees_to_be["per_call_charges"] * 100), # Convert to cents }, 'quantity': 1, }) @@ -143,7 +196,7 @@ class StripeServices: 'product_data': { 'name': 'Setup Fee', }, - 'unit_amount': fees_to_be["setup_fees"], + 'unit_amount': int(fees_to_be["setup_fees"] * 100), # Convert to cents }, 'quantity': 1, }) @@ -152,7 +205,8 @@ class StripeServices: "clinic_id": clinic_id, "unique_clinic_id": unique_clinic_id, "account_id": account_id, - "customer_id": customer_id + "customer_id": customer_id, + "fees_to_be": json.dumps(fees_to_be) } session_data = { @@ -173,13 +227,22 @@ class StripeServices: payment_log = PaymentLogs( customer_id=customer_id, account_id=account_id, - amount=fees_to_be["total"], + amount=Decimal(str(fees_to_be["total"])), # Keep as Decimal for database storage clinic_id=clinic_id, unique_clinic_id=unique_clinic_id, payment_status="pending", metadata_logs=json.dumps(metadata) ) + + new_payment_session = PaymentSessions( + session_id=session.id, + customer_id=customer_id, + clinic_id=clinic_id, + status="pending" + ) + self.db.add(payment_log) + self.db.add(new_payment_session) self.db.commit() return session @@ -201,25 +264,26 @@ class StripeServices: if event["type"] == "invoice.payment_succeeded": pass - # if event["type"] == "payment_intent.succeeded": - # await self.update_payment_log(event["data"]["object"]["metadata"]["unique_clinic_id"]) + if event["type"] == "checkout.session.async_payment_succeeded": + self.logger.info("Async payment succeeded") elif event["type"] == "checkout.session.completed": - await self.update_payment_log(event["data"]["object"]["metadata"]["unique_clinic_id"], event["data"]["object"]["metadata"]["clinic_id"]) + self.update_payment_log(event["data"]["object"]["metadata"]["unique_clinic_id"], event["data"]["object"]["metadata"]["clinic_id"]) + + # TODO: handle subscription period end return event except ValueError as e: self.logger.error(f"Invalid payload: {e}") - raise except stripe.error.SignatureVerificationError as e: self.logger.error(f"Invalid signature: {e}") - raise finally: self.db.close() - async def update_payment_log(self, unique_clinic_id:str, clinic_id:int): + def update_payment_log(self, unique_clinic_id:str, clinic_id:int): try: payment_log = self.db.query(PaymentLogs).filter(PaymentLogs.unique_clinic_id == unique_clinic_id).first() + self.db.query(PaymentSessions).filter(PaymentSessions.clinic_id == clinic_id).delete() if payment_log: payment_log.payment_status = "success" self.db.commit() @@ -231,6 +295,5 @@ class StripeServices: except Exception as e: self.logger.error(f"Error updating payment log: {e}") - raise finally: self.db.close() \ No newline at end of file diff --git a/services/userServices.py b/services/userServices.py index 8e38d0d..4807c0d 100644 --- a/services/userServices.py +++ b/services/userServices.py @@ -137,7 +137,6 @@ class UserServices: # Send mail to admin in a non-blocking way using background tasks if background_tasks: background_tasks.add_task(self._send_emails_to_admins, clinic.email) - # If no background_tasks provided, we don't send emails offer = await self.clinic_service.get_clinic_offer_by_clinic_email(clinic.email) @@ -155,7 +154,7 @@ class UserServices: fees_to_be["per_call_charges"] = offer.per_call_charges fees_to_be["total"] = offer.setup_fees + fees_to_be["subscription_fees"] + offer.per_call_charges - payment_link = await self.stripe_service.create_subscription_checkout(fees_to_be, new_clinic.id, stripe_account.id,stripe_customer.id) + payment_link = await self.stripe_service.create_subscription_checkout(fees_to_be, 1, stripe_account.id,stripe_customer.id) return payment_link.url except Exception as e: