diff --git a/apis/__init__.py b/apis/__init__.py index fdc326e..41dbcf8 100644 --- a/apis/__init__.py +++ b/apis/__init__.py @@ -2,12 +2,10 @@ from fastapi import APIRouter, Depends, Security from middleware.auth_dependency import auth_required from fastapi.security import HTTPBearer -from apis.endpoints import sns - # Import the security scheme bearer_scheme = HTTPBearer(scheme_name="Bearer Authentication") -from .endpoints import clinics, doctors, calender, appointments, patients, admin, auth, s3, users, clinicDoctor, dashboard, call_transcripts, notifications,sns +from .endpoints import clinics, doctors, calender, appointments, patients, admin, auth, s3, users, clinicDoctor, dashboard, call_transcripts, notifications,sns, stripe api_router = APIRouter() # api_router.include_router(twilio.router, prefix="/twilio") @@ -23,6 +21,8 @@ api_router.include_router(patients.router, prefix="/patients", tags=["patients"] api_router.include_router(sns.router, prefix="/sns", tags=["sns"], include_in_schema=False) +api_router.include_router(stripe.router, prefix="/stripe", tags=["stripe"]) + api_router.include_router( admin.router, prefix="/admin", diff --git a/apis/endpoints/auth.py b/apis/endpoints/auth.py index 170f1c9..e12bb7e 100644 --- a/apis/endpoints/auth.py +++ b/apis/endpoints/auth.py @@ -18,9 +18,9 @@ async def login(data: AuthBase): @router.post("/register") async def register(user_data: UserCreate, background_tasks: BackgroundTasks): - token = await AuthService().register(user_data, background_tasks) + response = await AuthService().register(user_data, background_tasks) return ApiResponse( - data=token, + data=response, message="User registered successfully" ) diff --git a/apis/endpoints/stripe.py b/apis/endpoints/stripe.py new file mode 100644 index 0000000..26b7547 --- /dev/null +++ b/apis/endpoints/stripe.py @@ -0,0 +1,28 @@ +from fastapi import APIRouter, Request +from services.stripeServices import StripeServices + +router = APIRouter() + +stripe_service = StripeServices() + +@router.post("/create-checkout-session") +async def create_checkout_session(user_id: int): + return await stripe_service.create_checkout_session(1) + +@router.post("/create-subscription-checkout") +async def create_subscription_checkout(): + return await stripe_service.create_subscription_checkout( + fees_to_be={ + "per_call_charges": 10, + "setup_fees": 100, + "subscription_fees": 100, + "total": 210 + }, + clinic_id=1, + account_id="acct_1RT1UFPTNqn2kWQ8", + customer_id="cus_SNn49FDltUcSLP" + ) + +@router.post("/webhook") +async def stripe_webhook(request: Request): + return await stripe_service.handle_webhook(request) \ No newline at end of file diff --git a/database.py b/database.py index 3e100eb..b6d2337 100644 --- a/database.py +++ b/database.py @@ -1,10 +1,11 @@ import dotenv +dotenv.load_dotenv() + from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker import os -dotenv.load_dotenv() engine = create_engine( os.getenv("DB_URL"), @@ -13,6 +14,7 @@ engine = create_engine( max_overflow=10, # Max extra connections when pool is full pool_recycle=3600, # Recycle connections after 1 hour echo=True, # Log SQL queries + connect_args={"sslmode": "require" if os.getenv("IS_DEV") == "False" else "disable"}, ) Base = declarative_base() # Base class for ORM models diff --git a/main.py b/main.py index 01c2024..b71dab0 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,13 @@ +import os import dotenv +dotenv.load_dotenv() + from fastapi import FastAPI from contextlib import asynccontextmanager import logging from fastapi.middleware.cors import CORSMiddleware from fastapi.security import HTTPBearer +import stripe # db from database import Base, engine @@ -16,7 +20,6 @@ from middleware.ErrorHandlerMiddleware import ErrorHandlerMiddleware, configure_ from middleware.CustomRequestTypeMiddleware import TextPlainMiddleware from services.emailService import EmailService -dotenv.load_dotenv() # Configure logging logging.basicConfig( @@ -26,12 +29,34 @@ logging.basicConfig( logger = logging.getLogger(__name__) +STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET") + @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Starting application") try: Base.metadata.create_all(bind=engine) logger.info("Created database tables") + + if STRIPE_SECRET_KEY is None or STRIPE_WEBHOOK_SECRET is None: + raise ValueError("Stripe API key or webhook secret is not set") + + stripe.api_key = STRIPE_SECRET_KEY + stripe.api_version = "2025-04-30.basil" + logger.info("Stripe API key set") + + # Test Stripe connection + try: + account = stripe.Account.retrieve() + logger.info(f"Stripe connection verified - Account ID: {account.id}") + except stripe.error.AuthenticationError as e: + logger.error(f"Stripe authentication failed: {e}") + raise + except stripe.error.StripeError as e: + logger.error(f"Stripe connection test failed: {e}") + raise + except Exception as e: logger.error(f"Error creating database tables: {e}") raise e diff --git a/models/PaymentLogs.py b/models/PaymentLogs.py new file mode 100644 index 0000000..f0b1102 --- /dev/null +++ b/models/PaymentLogs.py @@ -0,0 +1,16 @@ +from database import Base +from models.CustomBase import CustomBase +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship + +class PaymentLogs(Base, CustomBase): + __tablename__ = "payment_logs" + id = Column(Integer, primary_key=True, index=True) + customer_id = Column(String) + account_id = Column(String) + amount = Column(Integer) + clinic_id = Column(Integer) + unique_clinic_id = Column(String) + payment_status = Column(String) + metadata_logs = Column(String) + \ No newline at end of file diff --git a/models/StripeUsers.py b/models/StripeUsers.py new file mode 100644 index 0000000..c654181 --- /dev/null +++ b/models/StripeUsers.py @@ -0,0 +1,13 @@ +from database import Base +from models.CustomBase import CustomBase +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship + +class StripeUsers(Base, CustomBase): + __tablename__ = "stripe_users" + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False, unique=True) + customer_id = Column(String) + account_id = Column(String) + + user = relationship("Users", back_populates="stripe_user") \ No newline at end of file diff --git a/models/Users.py b/models/Users.py index ba3e76c..081b942 100644 --- a/models/Users.py +++ b/models/Users.py @@ -25,4 +25,7 @@ class Users(Base, CustomBase): # Clinics created by this user created_clinics = relationship("Clinics", back_populates="creator") - clinic_file_verifications = relationship("ClinicFileVerifications", back_populates="last_changed_by_user") \ No newline at end of file + clinic_file_verifications = relationship("ClinicFileVerifications", back_populates="last_changed_by_user") + + # Stripe relationships + stripe_user = relationship("StripeUsers", back_populates="user") \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py index bd3d5fa..3ed9484 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -16,6 +16,8 @@ from .ClinicFileVerifications import ClinicFileVerifications from .OTP import OTP from .ResetPasswordTokens import ResetPasswordTokens from .ClinicOffers import ClinicOffers +from .StripeUsers import StripeUsers +from .PaymentLogs import PaymentLogs __all__ = [ "Users", @@ -35,5 +37,7 @@ __all__ = [ "ClinicFileVerifications", "OTP", "ResetPasswordTokens", - "ClinicOffers" + "ClinicOffers", + "StripeUsers", + "PaymentLogs" ] diff --git a/requirements.txt b/requirements.txt index 15de012..ad429a7 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/services/authService.py b/services/authService.py index b9c1f72..386778b 100644 --- a/services/authService.py +++ b/services/authService.py @@ -55,17 +55,7 @@ class AuthService: async def register(self, user_data: UserCreate, background_tasks=None): response = await self.user_service.create_user(user_data, background_tasks) - user = { - "id": response.id, - "username": response.username, - "email": response.email, - "clinicRole": response.clinicRole, - "userType": response.userType, - "mobile": response.mobile, - "clinicId": response.created_clinics[0].id - } - token = create_jwt_token(user) - return token + return response def blockEmailSNS(self, body: str): # confirm subscription @@ -172,7 +162,7 @@ class AuthService: raise ValidationException("User with same email already exists") user = Users( - username=data.username.lower(), + username=data.username, email=data.email.lower(), password=hashed_password, userType=UserType.SUPER_ADMIN, diff --git a/services/clinicDoctorsServices.py b/services/clinicDoctorsServices.py index f59b221..7e589c8 100644 --- a/services/clinicDoctorsServices.py +++ b/services/clinicDoctorsServices.py @@ -157,13 +157,16 @@ class ClinicDoctorsServices: raise e async def delete_clinic_doctor(self, clinic_doctor_id: int): - clinic_doctor = ( - self.db.query(ClinicDoctors) - .filter(ClinicDoctors.id == clinic_doctor_id) - .first() - ) - self.db.delete(clinic_doctor) - self.db.commit() + try: + clinic_doctor = ( + self.db.query(ClinicDoctors) + .filter(ClinicDoctors.id == clinic_doctor_id) + .first() + ) + self.db.delete(clinic_doctor) + self.db.commit() + except Exception as e: + raise e async def get_doctor_status_count(self): diff --git a/services/clinicServices.py b/services/clinicServices.py index 0211baf..8378d63 100644 --- a/services/clinicServices.py +++ b/services/clinicServices.py @@ -265,6 +265,10 @@ class ClinicServices: pass return + + async def get_clinic_offer_by_clinic_email(self, clinic_email: str): + clinic_offer = self.db.query(ClinicOffers).filter(ClinicOffers.clinic_email == clinic_email).first() + return clinic_offer async def get_clinic_offers(self, user, limit:int, offset:int, search:str = ""): diff --git a/services/stripeServices.py b/services/stripeServices.py new file mode 100644 index 0000000..b31cbe0 --- /dev/null +++ b/services/stripeServices.py @@ -0,0 +1,230 @@ +import json +import os +import dotenv + +dotenv.load_dotenv() + +from models import Clinics +import uuid +from fastapi import Request +from datetime import datetime +from models import PaymentLogs +from enums.enums import ClinicStatus + +from database import get_db +from sqlalchemy.orm import Session +import stripe +from loguru import logger + + +class StripeServices: + def __init__(self): + self.db: Session = next(get_db()) + self.logger = logger + self.webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET") + + async def create_customer(self, user_id: int, email: str, name: str): + try: + customer = stripe.Customer.create( + email=email, name=name, metadata={"user_id": user_id} + ) + return customer + except stripe.error.StripeError as e: + self.logger.error(f"Error creating customer: {e}") + raise + + async def delete_customer(self, customer_id: str): + try: + stripe.Customer.delete(customer_id) + except stripe.error.StripeError as e: + self.logger.error(f"Error deleting customer: {e}") + raise + + async def create_account(self, user_id: int, email: str, name: str, phone: str): + try: + account = stripe.Account.create( + type="express", + country="AU", + capabilities={ + "card_payments": {"requested": True}, + "transfers": {"requested": True}, + }, + business_type="individual", + individual={"first_name": name, "email": email}, + metadata={"user_id": user_id}, + ) + return account + except stripe.error.StripeError as e: + self.logger.error(f"Error creating account: {e}") + raise + + async def delete_account(self, account_id: str): + try: + stripe.Account.delete(account_id) + except stripe.error.StripeError as e: + self.logger.error(f"Error deleting account: {e}") + raise + + async def create_checkout_session(self, user_id: int): + try: + checkout_session = stripe.checkout.Session.create( + payment_method_types=["card"], + line_items=[ + { + "price_data": { + "currency": "aud", + "product_data": { + "name": "Willio Voice Subscription", + }, + "unit_amount": 5000, + }, + "quantity": 1, + } + ], + expand=["payment_intent"], + mode="payment", + payment_intent_data={"metadata": {"order_id": "1"}}, + success_url="http://54.79.156.66/", + cancel_url="http://54.79.156.66/", + metadata={"user_id": user_id}, + ) + return checkout_session + except stripe.error.StripeError as e: + self.logger.error(f"Error creating checkout session: {e}") + raise + + async def create_setup_fees(self, customer_id: str, amount: int): + try: + setup_intent = stripe.InvoiceItem.create( + customer=customer_id, + amount=amount, + currency="aud", + description="Setup Fees", + ) + return setup_intent + except stripe.error.StripeError as e: + self.logger.error(f"Error creating setup intent: {e}") + raise + + async def create_subscription_checkout(self, fees_to_be: dict, clinic_id: int, account_id: str, customer_id: str): + try: + + unique_id = str(uuid.uuid4()) + unique_clinic_id = f"clinic_{clinic_id}_{unique_id}" + + line_items = [{ + 'price_data': { + 'currency': 'aud', + 'product_data': { + 'name': 'Monthly Subscription', + }, + 'unit_amount': fees_to_be["subscription_fees"], + 'recurring': { + 'interval': 'year', + }, + }, + 'quantity': 1, + }] + + line_items.append({ + 'price_data': { + 'currency': 'aud', + 'product_data': { + 'name': 'Per Call', + }, + 'unit_amount': fees_to_be["per_call_charges"], + }, + 'quantity': 1, + }) + + line_items.append({ + 'price_data': { + 'currency': 'aud', + 'product_data': { + 'name': 'Setup Fee', + }, + 'unit_amount': fees_to_be["setup_fees"], + }, + 'quantity': 1, + }) + + metadata = { + "clinic_id": clinic_id, + "unique_clinic_id": unique_clinic_id, + "account_id": account_id, + "customer_id": customer_id + } + + session_data = { + 'customer': customer_id, + 'payment_method_types': ['card'], + 'mode': 'subscription', + 'line_items': line_items, + 'success_url': 'http://54.79.156.66/', + 'cancel_url': 'http://54.79.156.66/', + 'metadata': metadata, + 'subscription_data': { + 'metadata': metadata + } + } + + session = stripe.checkout.Session.create(**session_data) + + payment_log = PaymentLogs( + customer_id=customer_id, + account_id=account_id, + amount=fees_to_be["total"], + clinic_id=clinic_id, + unique_clinic_id=unique_clinic_id, + payment_status="pending", + metadata_logs=json.dumps(metadata) + ) + self.db.add(payment_log) + self.db.commit() + + return session + except stripe.error.StripeError as e: + self.logger.error(f"Error creating checkout session: {e}") + raise + + + async def handle_webhook(self, request: Request): + try: + payload = await request.body() + event = stripe.Webhook.construct_event( + payload, request.headers.get("Stripe-Signature"), self.webhook_secret + ) + self.logger.info(f"Stripe webhook event type: {event['type']}") + + if event["type"] == "invoice.payment_succeeded": + pass + + # if event["type"] == "payment_intent.succeeded": + # await self.update_payment_log(event["data"]["object"]["metadata"]["unique_clinic_id"]) + + elif event["type"] == "checkout.session.completed": + await self.update_payment_log(event["data"]["object"]["metadata"]["unique_clinic_id"], event["data"]["object"]["metadata"]["clinic_id"]) + + return event + except ValueError as e: + self.logger.error(f"Invalid payload: {e}") + raise + except stripe.error.SignatureVerificationError as e: + self.logger.error(f"Invalid signature: {e}") + raise + + async def update_payment_log(self, unique_clinic_id:str, clinic_id:int): + try: + payment_log = self.db.query(PaymentLogs).filter(PaymentLogs.unique_clinic_id == unique_clinic_id).first() + if payment_log: + payment_log.payment_status = "success" + self.db.commit() + + clinic = self.db.query(Clinics).filter(Clinics.id == clinic_id).first() + if clinic: + clinic.status = ClinicStatus.UNDER_REVIEW + self.db.commit() + + except Exception as e: + self.logger.error(f"Error updating payment log: {e}") + raise \ No newline at end of file diff --git a/services/userServices.py b/services/userServices.py index 8de47bc..dede4d1 100644 --- a/services/userServices.py +++ b/services/userServices.py @@ -10,21 +10,29 @@ from enums.enums import ClinicStatus, UserType from schemas.UpdateSchemas import UserUpdate from exceptions.unauthorized_exception import UnauthorizedException from interface.common_response import CommonResponse -from exceptions.business_exception import BusinessValidationException -from models import ClinicFileVerifications +from models import ClinicFileVerifications, StripeUsers +from services.stripeServices import StripeServices from utils.password_utils import hash_password from schemas.CreateSchemas import UserCreate from exceptions.resource_not_found_exception import ResourceNotFoundException from exceptions.db_exceptions import DBExceptionHandler from sqlalchemy.orm import joinedload from services.emailService import EmailService - +from services.clinicServices import ClinicServices +from services.dashboardService import DashboardService class UserServices: def __init__(self): self.db: Session = next(get_db()) self.email_service = EmailService() + self.stripe_service = StripeServices() + self.clinic_service = ClinicServices() + self.dashboard_service = DashboardService() async def create_user(self, user_data: UserCreate, background_tasks=None): + + stripe_customer = None + stripe_account = None + # Start a transaction try: user = user_data.user @@ -53,6 +61,20 @@ class UserServices: # Add user to database but don't commit yet self.db.add(new_user) self.db.flush() # Flush to get the user ID without committing + + # Create stripe customer + stripe_customer = await self.stripe_service.create_customer(new_user.id, user.email, user.username) + + # Create stripe account + stripe_account = await self.stripe_service.create_account(new_user.id, user.email, user.username, user.mobile) + + # Create stripe user + stripe_user = StripeUsers( + user_id=new_user.id, + customer_id=stripe_customer.id, + account_id=stripe_account.id + ) + self.db.add(stripe_user) # Get clinic data clinic = user_data.clinic @@ -64,9 +86,7 @@ class UserServices: if existing_clinic: # This will trigger rollback in the exception handler - raise ValidationException("Clinic with same domain already exists") - - + raise ValidationException("Clinic with same domain already exists") # Create clinic instance new_clinic = Clinics( @@ -94,7 +114,7 @@ class UserServices: voice_model_gender=clinic.voice_model_gender, scenarios=clinic.scenarios, general_info=clinic.general_info, - status=ClinicStatus.UNDER_REVIEW, #TODO: change this to PAYMENT_DUE + status=ClinicStatus.PAYMENT_DUE, #TODO: change this to PAYMENT_DUE domain=domain, creator_id=new_user.id, # Set the creator_id to link the clinic to the user who created it ) @@ -119,17 +139,42 @@ class UserServices: background_tasks.add_task(self._send_emails_to_admins, clinic.email) # If no background_tasks provided, we don't send emails - return new_user + offer = await self.clinic_service.get_clinic_offer_by_clinic_email(clinic.email) + + signup_pricing = await self.dashboard_service.get_signup_pricing_master() + + fees_to_be = { + "setup_fees": signup_pricing.setup_fees, + "subscription_fees": signup_pricing.subscription_fees, + "per_call_charges": signup_pricing.per_call_charges, + "total": signup_pricing.setup_fees + signup_pricing.subscription_fees + signup_pricing.per_call_charges + } + + if offer: + fees_to_be["setup_fees"] = offer.setup_fees + fees_to_be["per_call_charges"] = offer.per_call_charges + fees_to_be["total"] = offer.setup_fees + fees_to_be["subscription_fees"] + offer.per_call_charges + + payment_link = await self.stripe_service.create_subscription_checkout(fees_to_be, new_clinic.id, stripe_account.id,stripe_customer.id) + + return payment_link.url except Exception as e: logger.error(f"Error creating user: {str(e)}") # Rollback the transaction if any error occurs self.db.rollback() - + + # Delete stripe customer and account + if stripe_customer: + await self.stripe_service.delete_customer(stripe_customer.id) + if stripe_account: + await self.stripe_service.delete_account(stripe_account.id) + # Use the centralized exception handler DBExceptionHandler.handle_exception(e, context="creating user") finally: self.db.commit() + async def get_user(self, user_id) -> UserResponse: try: # Query the user by ID and explicitly load the created clinics relationship @@ -254,4 +299,13 @@ class UserServices: ) except Exception as e: # Log the error but don't interrupt the main flow - logger.error(f"Error sending admin emails: {str(e)}") \ No newline at end of file + logger.error(f"Error sending admin emails: {str(e)}") + + async def create_payment_link(self, user_id: int): + user = self.db.query(Users).filter(Users.id == user_id).first() + + if not user: + logger.error("User not found") + raise ResourceNotFoundException("User not found") + + return self.stripe_service.create_payment_link(user_id) \ No newline at end of file