from datetime import datetime, timedelta
from core.db.mongo import otp_collection, otp_ip_block_collection, recovery_verifications_collection, account_recovery_otp_collection,account_recovery_rate_limit_collection
from bson import ObjectId
import os
from pymongo import DESCENDING
from datetime import datetime, timedelta, timezone


# Primary-OTP tuning, read once at import time and parsed as floats.
# Lifetime of a primary OTP, in minutes (used by save_otp).
OTP_EXPIRY_MINUTES = float(os.environ.get("OTP_EXPIRY_MINUTES", 10))
# Minimum number of seconds between two sends from the same IP.
OTP_GAP_SECONDS = float(os.environ.get("OTP_GAP_SECONDS", 30))
# Sends allowed per IP before can_send_otp blocks it.
MAX_OTP_REQUESTS = float(os.environ.get("MAX_OTP_REQUESTS", 3))
# Duration of an IP block, in hours.
BLOCK_HOURS = float(os.environ.get("OTP_BLOCK_HOURS", 3))


def is_ip_blocked(ip: str):
    """Return True when ``ip`` currently has an active OTP block.

    A block is active if a document in ``otp_ip_block_collection`` matches
    the IP and its ``blocked_until`` lies strictly after the current naive
    UTC time.
    """
    active_block = {
        "ip_address": ip,
        "blocked_until": {"$gt": datetime.utcnow()},
    }
    return otp_ip_block_collection.find_one(active_block) is not None


def block_ip(ip: str, user_id: str):
    """Block ``ip`` from OTP requests for ``BLOCK_HOURS`` hours.

    Upserts the IP's document in ``otp_ip_block_collection``. Only
    ``user_id``, ``blocked_until`` and ``created_at`` are written; other
    fields of an existing document are left untouched, and a newly created
    document has no ``attempts``/``last_sent_at``.
    """
    new_values = {
        "user_id": ObjectId(user_id),
        "blocked_until": datetime.utcnow() + timedelta(hours=BLOCK_HOURS),
        "created_at": datetime.utcnow(),
    }
    otp_ip_block_collection.update_one(
        {"ip_address": ip}, {"$set": new_values}, upsert=True
    )



# A byte-identical second definition of is_ip_blocked() used to sit here and
# silently rebind the name; the single definition earlier in this module is
# the one callers use.



def can_send_otp(user_id: str, ip: str, identifier, purpose="primary"):
    """Decide whether an OTP may be sent from ``ip`` and record the attempt.

    The limit is tracked in one ``otp_ip_block_collection`` document per
    ``ip_address``:

    * while ``blocked_until`` is in the future the request is refused and the
      block is NOT extended;
    * once ``blocked_until`` has passed, the attempt counter and gap timer are
      reset;
    * two sends must be at least ``OTP_GAP_SECONDS`` apart;
    * the send that would take ``attempts`` to ``MAX_OTP_REQUESTS + 1`` is
      refused and blocks the IP for ``BLOCK_HOURS`` hours.

    The tracking document may have been created by ``block_ip``, which writes
    only ``user_id``/``blocked_until``/``created_at``. Missing ``attempts`` and
    ``last_sent_at`` are therefore treated as 0 and "never sent" rather than
    raising ``KeyError``.

    Args:
        user_id: converted with ``ObjectId``; stored only when the IP's
            document is first created.
        ip: client IP address the limit is keyed on.
        identifier: stored only when the IP's document is first created.
        purpose: stored only when the IP's document is first created.

    Returns:
        ``(True, None)`` if the OTP may be sent, else ``(False, message)``.
    """
    now = datetime.utcnow()

    record = otp_ip_block_collection.find_one({"ip_address": ip})
    blocked_until = record.get("blocked_until") if record else None
    last_sent_at = record.get("last_sent_at") if record else None

    # 1️⃣ Already blocked → DO NOT extend block
    if blocked_until and blocked_until > now:
        return False, f"Too many requests. Try after {BLOCK_HOURS} hours."

    # 🔄 Block expired → reset attempts (and the gap timer, mirroring the DB)
    if blocked_until:
        otp_ip_block_collection.update_one(
            {"ip_address": ip},
            {
                "$set": {
                    "attempts": 0,
                    "blocked_until": None,
                    "last_sent_at": None
                }
            }
        )
        record["attempts"] = 0
        record["blocked_until"] = None
        last_sent_at = None

    # 2️⃣ First request
    if not record:
        otp_ip_block_collection.insert_one({
            "ip_address": ip,
            "user_id": ObjectId(user_id),
            "attempts": 1,
            "first_attempt_at": now,
            "last_sent_at": now,
            "blocked_until": None,
            "purpose": purpose,
            "identifier": identifier
        })
        return True, None

    # 3️⃣ Enforce the minimum gap between sends (nothing to compare against
    # if no send has been recorded yet)
    if last_sent_at is not None:
        diff = (now - last_sent_at).total_seconds()
        if diff < OTP_GAP_SECONDS:
            return False, f"Wait {int(OTP_GAP_SECONDS - diff)} seconds"

    # 4️⃣ Increment attempt (a missing counter counts as 0)
    attempts = (record.get("attempts") or 0) + 1

    # 🚫 Apply block ONLY ONCE
    if attempts >= MAX_OTP_REQUESTS + 1:
        otp_ip_block_collection.update_one(
            {"ip_address": ip},
            {
                "$set": {
                    "attempts": attempts,
                    "blocked_until": now + timedelta(hours=BLOCK_HOURS),
                }
            }
        )
        return False, f"Too many OTP requests. Blocked for {BLOCK_HOURS} hours."

    # ✅ Allow OTP ($inc creates the counter at 1 if it was missing)
    otp_ip_block_collection.update_one(
        {"ip_address": ip},
        {
            "$set": {"last_sent_at": now},
            "$inc": {"attempts": 1}
        }
    )

    return True, None


def save_otp(user_id: str, otp_hash: str, ip: str, to_verify, identifier, purpose="primary"):
    """Insert a new, unverified OTP document into ``otp_collection``.

    ``created_at`` and ``expires_at`` come from a single clock reading, so the
    stored lifetime is exactly ``OTP_EXPIRY_MINUTES``. ``verify_attempts``
    starts at 0, as in ``save_temp_recovery_otp``. ``increment_verify_attempt``
    bumps it with ``$inc``.

    Args:
        user_id: converted with ``ObjectId``.
        otp_hash: stored as given.
        ip: requesting IP address, stored as ``ip_address``.
        to_verify: stored as given; matched by ``get_valid_otp``.
        identifier: stored as given; matched by ``get_valid_otp``.
        purpose: stored as given; matched by ``get_valid_otp``.
    """
    now = datetime.utcnow()
    otp_collection.insert_one({
        "user_id": ObjectId(user_id),
        "otp_hash": otp_hash,
        "ip_address": ip,
        "created_at": now,
        "expires_at": now + timedelta(minutes=OTP_EXPIRY_MINUTES),
        "verified": False,
        "verify_attempts": 0,
        "purpose": purpose,
        "to_verify": to_verify,
        "identifier": identifier,
    })




def get_valid_otp(user_id: str, to_verify, identifier, purpose="primary"):
    """Return the latest usable OTP document, or ``None``.

    A document is usable when it belongs to ``user_id``, has not been
    verified, has not expired (``expires_at`` >= current naive UTC time), and
    matches ``purpose``, ``to_verify`` and ``identifier`` exactly. If several
    match, the one with the highest ``_id`` is returned.
    """
    query = {
        "user_id": ObjectId(user_id),
        "verified": False,
        "expires_at": {"$gte": datetime.utcnow()},
        "purpose": purpose,
        "to_verify": to_verify,
        "identifier": identifier,
    }
    highest_id_first = [("_id", DESCENDING)]
    return otp_collection.find_one(query, sort=highest_id_first)





# Upper bound on verification attempts per OTP. Not referenced in this
# module's visible code; presumably callers compare it against the
# ``verify_attempts`` counter maintained by increment_verify_attempt — verify.
MAX_VERIFY_ATTEMPTS = 5


def increment_verify_attempt(otp_id):
    """Add one to the ``verify_attempts`` field of the OTP ``otp_id``.

    Uses ``$inc``, so the field is created with value 1 if it is absent.
    """
    by_id = {"_id": ObjectId(otp_id)}
    bump = {"$inc": {"verify_attempts": 1}}
    otp_collection.update_one(by_id, bump)


def mark_otp_verified(otp_id):
    """Set ``verified`` to True on the OTP document ``otp_id``."""
    by_id = {"_id": ObjectId(otp_id)}
    otp_collection.update_one(by_id, {"$set": {"verified": True}})


def delete_all_user_otps(user_id, purpose="primary"):
    """Delete every OTP document of ``user_id`` whose ``purpose`` matches.

    The filter key must be the literal field name ``"purpose"``. Previously
    the variable was used as the key (``{purpose: purpose}``), which built
    e.g. ``{"primary": "primary"}`` and matched no documents.
    """
    otp_collection.delete_many({"user_id": ObjectId(user_id), "purpose": purpose})


def clear_ip_block(ip: str, user_id: str):
    """Delete the IP-tracking document for ``ip`` owned by ``user_id``.

    At most one document is removed, and only if both fields match.
    """
    match = dict(ip_address=ip, user_id=ObjectId(user_id))
    otp_ip_block_collection.delete_one(match)




    
    
def save_temp_recovery_otp(user_id, identifier, otp_hash, purpose, expiry_minutes=10):
    """Insert a pending recovery OTP into ``recovery_verifications_collection``.

    purpose: RECOVERY_EMAIL or RECOVERY_PHONE

    ``created_at`` and ``expires_at`` come from one clock reading, so the
    stored lifetime is exactly ``expiry_minutes``. The document starts
    unverified with ``verify_attempts`` = 0.

    Args:
        user_id: converted with ``ObjectId``.
        identifier: stored as given.
        otp_hash: stored as given.
        purpose: stored as given (see above).
        expiry_minutes: OTP lifetime in minutes; defaults to the previously
            hard-coded 10.
    """
    now = datetime.utcnow()
    recovery_verifications_collection.insert_one({
        "user_id": ObjectId(user_id),
        "identifier": identifier,
        "purpose": purpose,
        "otp_hash": otp_hash,
        "verify_attempts": 0,
        "verified": False,
        "created_at": now,
        "expires_at": now + timedelta(minutes=expiry_minutes),
    })
    
    
    
    
    
       
ACCOUNT_OTP_GAP_SECONDS = os.getenv("ACCOUNT_OTP_GAP_SECONDS", 30)
ACCOUNT_OTP_MAX_REQUESTS = os.getenv("ACCOUNT_OTP_MAX_REQUESTS", 3)
ACCOUNT_OTP_BLOCK_HOURS = os.getenv("ACCOUNT_OTP_BLOCK_HOURS", 3)
ACCOUNT_OTP_EXPIRY_MINUTES = os.getenv("ACCOUNT_OTP_EXPIRY_MINUTES", 10)



def can_send_account_otp(ip, identifier, purpose="ACCOUNT_RECOVERY"):
    """Rate-limit account-recovery OTP sends per (``ip``, ``identifier``) pair.

    Tracking lives in ``account_recovery_rate_limit_collection``:

    * first request → document created with ``attempts`` = 1, allowed;
    * active block (``blocked_until`` in the future) → refused, not extended;
    * expired block → counters and gap timer reset, then evaluated normally;
    * sends closer than ``ACCOUNT_OTP_GAP_SECONDS`` apart → refused;
    * the send that would exceed ``ACCOUNT_OTP_MAX_REQUESTS`` → refused and
      blocked for ``ACCOUNT_OTP_BLOCK_HOURS`` hours.

    Stored naive datetimes are interpreted as UTC.

    Args:
        ip: client IP address.
        identifier: account identifier the limit is scoped to.
        purpose: currently unused.

    Returns:
        ``(True, None)`` if the OTP may be sent, else ``(False, message)``.
    """
    # The ACCOUNT_OTP_* settings may be strings (os.getenv returns str when the
    # variable is set); coerce defensively so comparisons/timedelta work.
    gap_seconds = float(ACCOUNT_OTP_GAP_SECONDS)
    max_requests = float(ACCOUNT_OTP_MAX_REQUESTS)
    block_hours = float(ACCOUNT_OTP_BLOCK_HOURS)

    now = datetime.now(timezone.utc)

    record = account_recovery_rate_limit_collection.find_one({
        "ip_address": ip,
        "identifier": identifier
    })

    # 1️⃣ First request
    if not record:
        account_recovery_rate_limit_collection.insert_one({
            "ip_address": ip,
            "identifier": identifier,
            "attempts": 1,
            "first_attempt_at": now,
            "last_sent_at": now,
            "blocked_until": None
        })
        return True, None

    # Normalize timestamps (naive values from Mongo are treated as UTC)
    last_sent_at = record.get("last_sent_at")
    if last_sent_at and last_sent_at.tzinfo is None:
        last_sent_at = last_sent_at.replace(tzinfo=timezone.utc)

    blocked_until = record.get("blocked_until")
    if blocked_until:
        if blocked_until.tzinfo is None:
            blocked_until = blocked_until.replace(tzinfo=timezone.utc)

        if blocked_until > now:
            return False, f"Too many requests. Try again after {ACCOUNT_OTP_BLOCK_HOURS} hours."

        # 🔥 Block expired → RESET
        account_recovery_rate_limit_collection.update_one(
            {"_id": record["_id"]},
            {"$set": {
                "attempts": 0,
                "blocked_until": None,
                "first_attempt_at": now,
                "last_sent_at": None
            }}
        )
        record["attempts"] = 0
        last_sent_at = None

    # ⏱ Gap check
    if last_sent_at:
        diff = (now - last_sent_at).total_seconds()
        if diff < gap_seconds:
            return False, f"Wait {int(gap_seconds - diff)} seconds"

    attempts = (record.get("attempts") or 0) + 1

    # 🔒 Block once
    if attempts > max_requests:
        account_recovery_rate_limit_collection.update_one(
            {"_id": record["_id"]},
            {"$set": {
                "attempts": attempts,
                "blocked_until": now + timedelta(hours=block_hours)
            }}
        )
        return False, f"Too many requests. Try again after {ACCOUNT_OTP_BLOCK_HOURS} hours."

    # ✅ Normal update
    account_recovery_rate_limit_collection.update_one(
        {"_id": record["_id"]},
        {
            "$set": {"last_sent_at": now},
            "$inc": {"attempts": 1}
        }
    )

    return True, None


def save_account_otp(identifier, otp_hash, ip, meta=None, purpose="ACCOUNT_RECOVERY"):
    """Insert an account-recovery OTP into ``account_recovery_otp_collection``.

    The document is unverified and expires ``ACCOUNT_OTP_EXPIRY_MINUTES``
    after ``created_at`` (timezone-aware UTC).

    NOTE(review): ``attempts``/``first_attempt_at``/``last_sent_at``/
    ``blocked_until`` mirror the rate-limit document's fields; confirm whether
    anything reads them from this collection.

    Args:
        identifier: stored as given.
        otp_hash: stored as given.
        ip: requesting IP, stored as ``ip_address``.
        meta: optional extra data; stored as ``{}`` when falsy.
        purpose: stored as given.
    """
    # ACCOUNT_OTP_EXPIRY_MINUTES may be a string when set via the environment;
    # timedelta() requires a number.
    expiry_minutes = float(ACCOUNT_OTP_EXPIRY_MINUTES)
    now = datetime.now(timezone.utc)
    account_recovery_otp_collection.insert_one({
        "identifier": identifier,
        "otp_hash": otp_hash,
        "ip_address": ip,
        "purpose": purpose,
        "meta": meta or {},
        "attempts": 1,
        "first_attempt_at": now,
        "last_sent_at": now,
        "blocked_until": None,
        "verified": False,
        "created_at": now,
        "expires_at": now + timedelta(minutes=expiry_minutes)
    })
