import jwt
import os
from datetime import datetime, timedelta, timezone
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from jose import JWTError

# Signing secret shared by every token helper in this module, read from the
# JWT_SECRET_KEY environment variable. It is None when the variable is unset.
# NOTE(review): consider failing fast at startup when the secret is missing.
JWT_SECRET = os.environ.get("JWT_SECRET_KEY")
# HMAC-SHA256, used both when encoding and when decoding tokens below.
JWT_ALGO = "HS256"

def generate_access_token(
    user_id: str,
    session_id: str,
    *,
    expires_in: timedelta = timedelta(minutes=15),
):
    """Create a signed, short-lived access token.

    Args:
        user_id: Stored in the standard ``sub`` claim.
        session_id: Stored in the custom ``session_id`` claim.
        expires_in: Token lifetime; defaults to 15 minutes.

    Returns:
        The token produced by ``jwt.encode`` with ``JWT_SECRET`` and
        ``JWT_ALGO``; ``type`` is set to ``"access"``.
    """
    # Read the clock once so ``iat`` and ``exp`` derive from the same instant.
    # Use an aware UTC datetime rather than the deprecated naive utcnow().
    now = datetime.now(timezone.utc)
    payload = {
        "sub": user_id,
        "session_id": session_id,
        "type": "access",
        "iat": int(now.timestamp()),
        "exp": int((now + expires_in).timestamp()),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGO)

# ``generate_refresh_token`` is defined exactly once, further down next to
# ``verify_refresh_token``. Keeping the pair together keeps their claim names
# in agreement: the refresh token carries ``user_id``, not ``sub``. A second
# definition here would only be silently shadowed by that later one.



def verify_access_token(token: str):
    """Decode and validate an access token.

    Args:
        token: The encoded JWT to check.

    Returns:
        The decoded claims dict, or ``None`` when any of these holds:
        - the signature or algorithm is invalid;
        - the token is expired or carries no ``exp``;
        - ``type`` is not ``"access"``;
        - ``sub`` or ``session_id`` is missing or empty.
    """
    try:
        payload = jwt.decode(
            token,
            JWT_SECRET,
            # Must be a list. A bare string only appeared to work because
            # PyJWT checks ``alg not in algorithms``, which on a str is a
            # substring test.
            algorithms=[JWT_ALGO],
            # PyJWT validates ``exp`` only when present; requiring it rejects
            # tokens that would otherwise never expire.
            options={"require": ["exp"]},
        )
    except (ExpiredSignatureError, InvalidTokenError):
        return None

    if payload.get("type") != "access":
        return None
    if not payload.get("sub") or not payload.get("session_id"):
        return None
    return payload
    
    

def generate_refresh_token(
    user_id: str,
    session_id: str,
    *,
    expires_in: timedelta = timedelta(days=7),
):
    """Create a signed, long-lived refresh token.

    The subject is stored under ``user_id`` rather than ``sub`` because that
    is the claim ``verify_refresh_token`` requires.

    Args:
        user_id: Stored in the ``user_id`` claim.
        session_id: Stored in the ``session_id`` claim.
        expires_in: Token lifetime; defaults to 7 days.

    Returns:
        The token produced by ``jwt.encode`` with ``JWT_SECRET`` and
        ``JWT_ALGO``; ``type`` is set to ``"refresh"``.
    """
    # A single clock read makes ``exp - iat`` exactly ``expires_in``.
    now = datetime.now(timezone.utc)
    payload = {
        "user_id": user_id,
        "session_id": session_id,
        "type": "refresh",
        "exp": int((now + expires_in).timestamp()),
        "iat": int(now.timestamp()),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGO)


def verify_refresh_token(token: str):
    """Decode and validate a refresh token.

    Args:
        token: The encoded JWT to check.

    Returns:
        The decoded claims dict, or ``None`` when any of these holds:
        - the signature or algorithm is invalid;
        - the token is expired or carries no ``exp``;
        - ``type`` is not ``"refresh"``;
        - ``user_id`` or ``session_id`` is missing or empty.
    """
    try:
        payload = jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[JWT_ALGO],
            # PyJWT enforces expiry (exp <= now is rejected) whenever ``exp``
            # is present. Requiring it turns a missing ``exp`` into a rejected
            # token instead of a TypeError from comparing None to an int.
            options={"require": ["exp"]},
        )
    except (ExpiredSignatureError, InvalidTokenError):
        # ``jwt`` is PyJWT, whose errors derive from InvalidTokenError.
        # python-jose's JWTError is an unrelated class and would not match.
        return None

    # Reject access tokens presented where a refresh token is expected.
    if payload.get("type") != "refresh":
        return None

    # Both identifiers are required by callers of this function.
    if not payload.get("user_id") or not payload.get("session_id"):
        return None

    return payload





def create_oauth_access_token(
    user_id,
    email,
    client_id,
    *,
    issuer: str = "auth.yourdomain.com",
    expires_in: timedelta = timedelta(hours=1),
):
    """Create a signed OAuth access token addressed to ``client_id``.

    Args:
        user_id: Converted with ``str()`` and stored in the ``sub`` claim.
        email: Stored in the ``email`` claim.
        client_id: Stored as the audience (``aud``) claim.
        issuer: Value of the ``iss`` claim. NOTE(review): the default looks
            like a placeholder domain; confirm the real issuer.
        expires_in: Token lifetime; defaults to one hour.

    Returns:
        The token produced by ``jwt.encode`` with ``JWT_SECRET`` and
        ``JWT_ALGO``.
    """
    # One aware UTC clock read for both timestamps (no deprecated utcnow()).
    now = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "email": email,
        "aud": client_id,
        "iss": issuer,
        "exp": int((now + expires_in).timestamp()),
        "iat": int(now.timestamp()),
    }
    # The claims are intentionally not printed or logged: they include the
    # user's email address.
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGO)
