import logging
import os
from datetime import datetime, timedelta, timezone
import bcrypt
import base64
import uuid
import rsa
import secrets
from security import decryptRSA
from flask import session, request, Flask, abort, Response, current_app
from flask_jwt_extended import create_access_token, create_refresh_token, jwt_required, get_jwt_identity
from flask_wtf.csrf import generate_csrf
from median.models import User, HistoriqueIdentification, Profil
from flask_wtf import CSRFProtect
from common.status import (
    HTTP_400_BAD_REQUEST,
    HTTP_401_UNAUTHORIZED,
    HTTP_200_OK,
    HTTP_202_ACCEPTED,
    HTTP_204_NO_CONTENT,
)
from median.database import crypte

logger = logging.getLogger("median.webserver")

# Module-level CSRF protection; it is bound to the Flask app in init_session().
csrf = CSRFProtect()

# Session and access-token lifetime, in minutes. TIMEOUT_SESSION overrides the
# default; an unset variable and a value of "0" both yield 30.
timeoutSession = int(os.environ.get("TIMEOUT_SESSION", "0")) or 30

# Key used by Flask to sign the session. MEDIAN_SECRET_KEY (UTF-8 encoded) is used
# when set and non-empty; otherwise a random 16-byte key is drawn at import time.
# NOTE(review): a random key changes on every restart (invalidating sessions) and,
# presumably, differs between worker processes -- set MEDIAN_SECRET_KEY in production.
secret_key = os.environ.get("MEDIAN_SECRET_KEY", "").encode("utf8") or os.urandom(16)


def _init_jwt(app: Flask):
    """Apply the flask_jwt_extended settings to *app*.

    ``JWT_SECRET_KEY`` is read from the ``MEDIAN_JWT_SECRET_KEY`` environment
    variable. When that variable is not set, a random key is generated each time
    this runs, so tokens issued before a restart no longer validate (and,
    presumably, separate worker processes do not share a key).
    Access tokens expire after ``timeoutSession`` minutes, refresh tokens after
    one day.
    """
    jwt_secret = os.environ.get("MEDIAN_JWT_SECRET_KEY")
    if jwt_secret is None:
        # Generate the fallback only when it is actually needed, straight from the
        # CSPRNG. The former base64-wrapped uuid4 string was computed even when the
        # variable was set, added no entropy, and was bytes while the environment
        # value is str.
        jwt_secret = secrets.token_urlsafe(32)
    app.config["JWT_SECRET_KEY"] = jwt_secret
    app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=timeoutSession)
    app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=1)


def init_session(app: Flask):
    """Configure session handling, CSRF protection and JWT on *app*.

    Installs the module's session signing key, sets the permanent-session
    lifetime to ``timeoutSession`` minutes, registers the module-level
    CSRFProtect instance and applies the JWT settings.
    """
    lifetime = timedelta(minutes=timeoutSession)

    app.secret_key = secret_key
    app.permanent_session_lifetime = lifetime
    csrf.init_app(app)
    _init_jwt(app)


def _populate_rights(usr):
    """Store the access rights of *usr*'s profile in ``session["rights"]``.

    The stored value maps each ``Profil.ressource`` of the user's profile to
    ``{"visu": <visu>, "edit": <edit>}``. When the session is a device session
    (``session["device"]`` truthy), only resources whose name starts with
    ``WEB_AIDEPLUS`` are kept.

    Errors are logged rather than raised; in that case ``session["rights"]`` is
    reset to an empty dict so no partial or stale rights remain in the session.
    """
    try:
        where_condition = Profil.profil == usr.profil

        if session.get("device"):
            # A device login only gets a restricted subset of rights.
            # TODO: Fixed values for AidePlus, edit this if other device types are made
            where_condition = where_condition & (Profil.ressource.startswith("WEB_AIDEPLUS"))

        rights_query = Profil.select(Profil.ressource, Profil.visu, Profil.edit).where(where_condition)

        # Build the mapping locally and assign it once: the session never holds a
        # half-filled dict, and the assignment itself marks the session modified.
        rights = {right.ressource: {"visu": right.visu, "edit": right.edit} for right in rights_query}
        session["rights"] = rights

    except Exception as e:
        logger.error("Login : Adding rights to session failed : %s", e, exc_info=True)
        # Fail closed: do not keep rights that an earlier login left in the session.
        session["rights"] = {}


def populate_session(usr, user_id, app=None):
    """Write the authenticated user's identity and rights into the Flask session.

    Stores the username, email, user id and language (``"en_US"`` when the user
    has none), loads the profile rights, then marks the session permanent with a
    lifetime of ``timeoutSession`` minutes.

    ``app`` is accepted for backward compatibility but is not used; the lifetime
    is set on ``current_app``.
    """
    session.update(
        {
            "username": usr.username,
            "email": usr.email,
            "user_id": user_id,
            "user_lang": usr.lang or "en_US",
        }
    )

    _populate_rights(usr)

    session.permanent = True
    # NOTE(review): this rewrites the app-wide lifetime on every login; init_session()
    # already sets the same value.
    current_app.permanent_session_lifetime = timedelta(minutes=timeoutSession)


def handle_csrf_error(e):
    """Reject a request that failed CSRF validation.

    The error object ``e`` is not inspected. ``abort`` raises an HTTP 400
    "Bad request" exception, so this function never returns normally.
    """
    abort(HTTP_400_BAD_REQUEST, "Bad request")


def get_csrf_token():
    """Return an empty response whose ``X-CSRF-Token`` header holds a fresh CSRF token."""
    token = generate_csrf()
    response = Response()
    response.headers["X-CSRF-Token"] = token
    return response


def generate_public_key():
    """Create a fresh RSA key pair for encrypting the login credentials.

    The PEM-encoded private key is kept in ``session["privateKey"]``, where the
    login handlers read it back to decrypt the form fields; the PEM-encoded
    public key is returned to the client.

    Returns:
        ({"key": <public key PEM as str>}, 200)
    """
    # NOTE(review): 1024-bit RSA is below current recommendations (2048+).
    # NOTE(review): with Flask's default cookie session, the private key travels to
    # the client in a signed but unencrypted cookie -- confirm a server-side session
    # backend is in use, otherwise this encryption offers little protection.
    public_key, private_key = rsa.newkeys(1024)
    session["privateKey"] = private_key.save_pkcs1("PEM")
    public_pem = public_key.save_pkcs1("PEM").decode("utf8")
    return {"key": public_pem}, HTTP_200_OK


@jwt_required(refresh=True)
def refresh_token():
    """Exchange a valid refresh token for a new access/refresh token pair.

    The JWT identity is the user's primary key as a string. The user is re-read
    from the database so that a deleted or disabled account can no longer obtain
    tokens, and the session is repopulated from the current user row.

    Returns:
        ({"refreshToken": ..., "token": ...}, 200) on success;
        ({"message": "connection.notexist"}, 401) if the user no longer exists;
        ({"message": "connection.disabled"}, 401) if the account is disabled.
    """
    identity = get_jwt_identity()
    usr = User.get_or_none(User.pk == identity)

    if usr is None:
        # Previously User.get raised DoesNotExist here (HTTP 500).
        logger.warning(f"Refresh token: user {identity} not found")
        return {"message": "connection.notexist"}, HTTP_401_UNAUTHORIZED

    if not usr.isEnabled:
        # A disabled account must not keep extending its access via refresh tokens.
        logger.warning(f"Refresh token: user {usr.username} ({usr.pk}) is disable")
        return {"message": "connection.disabled"}, HTTP_401_UNAUTHORIZED

    # Same token contents as a login; local names no longer shadow this function.
    new_refresh_token, access_token = _create_user_tokens(usr, identity)

    populate_session(usr, usr.pk)

    return {"refreshToken": new_refresh_token, "token": access_token}, HTTP_200_OK


def _create_user_tokens(usr, user_id):
    """Create the JWT pair issued to an authenticated user.

    Both tokens use ``str(user_id)`` as identity; the access token additionally
    carries the username, email and language (``"en_US"`` by default) as claims.

    Returns:
        (refresh_token, access_token)
    """
    subject = str(user_id)
    claims = {
        "username": usr.username,
        "email": usr.email,
        "user_lang": usr.lang or "en_US",
    }
    refresh = create_refresh_token(identity=subject)
    access = create_access_token(identity=subject, additional_claims=claims)
    return refresh, access


def _handle_login_success(usr, user_id):
    """Finish a successful authentication and build the login response.

    Resets the retry counter to 5 and records the login date, appends a
    "MWEB: Login" entry to the identification history, issues JWTs, fills the
    session and stamps ``session["last_action_time"]`` with the current UTC time.

    Returns:
        ({"refreshToken", "token", "route"}, 202) where ``route`` is
        ``/renewpassword.html`` for a temporary password and ``/`` otherwise.
    """
    User.update(retry=5, date_last_login=datetime.now()).where(User.pk == user_id).execute()

    history = HistoriqueIdentification()
    history.user = usr.username
    history.demande = "MWEB: Login"
    history.save()

    refresh_jwt, access_jwt = _create_user_tokens(usr, user_id)
    populate_session(usr, user_id)
    session["last_action_time"] = datetime.now(timezone.utc)

    if usr.isTemporary:
        logger.info("Password is temporary, ask for changes")
        # TODO: Security: Try to see if we can bypass this renewpassword route (and if so, fix it)
        # The session has been made, so technically the user is authenticated;
        # any route may work even while the user is still in "temporary" mode.
        route = "/renewpassword.html"
    else:
        logger.info("User: Authentication is OK")
        route = "/"

    return {"refreshToken": refresh_jwt, "token": access_jwt, "route": route}, HTTP_202_ACCEPTED


def login_user():
    """Authenticate a user from the RSA-encrypted login form.

    ``request.form["username"]`` (matched against login or email) and
    ``request.form["password"]`` are decrypted with the private key that
    generate_public_key() stored in ``session["privateKey"]``. The password is
    verified against the bcrypt hash in ``passwordWeb``; each mismatch
    decrements the user's retry counter.

    Returns:
        The response of _handle_login_success() on success, otherwise
        ({"message": <i18n key>}, 401) with one of ``connection.notexist``,
        ``connection.disabled``, ``connection.max_retry`` or ``connection.failed``.
    """
    try:
        user = decryptRSA(request.form["username"], session["privateKey"])
        logger.info(f"User: {user}")
        usr = User.select().where(((User.email == user) | (User.login == user))).get()
        logger.info(f"User {usr.username} line found {usr.pk}")

        user_id = usr.pk

        if not usr.isEnabled:
            logger.warning(f"User: {usr.username} ({usr.pk}) is disable")
            return {"message": "connection.disabled"}, HTTP_401_UNAUTHORIZED

        if usr.retry <= 0:
            logger.warning(f"User: {usr.username} ({usr.pk}) Maximum de tentatives atteintes")
            return {"message": "connection.max_retry"}, HTTP_401_UNAUTHORIZED

    except Exception as e:
        # Missing key in session, undecryptable input or unknown user.
        logger.error(str(e.args))
        return {"message": "connection.notexist"}, HTTP_401_UNAUTHORIZED

    try:
        if user_id is None or usr.passwordWeb is None:
            logger.warning("User: Failed to retrieve user or passwordweb in database")
            return {"message": "connection.failed"}, HTTP_401_UNAUTHORIZED

        password = decryptRSA(request.form["password"], session["privateKey"])

        if not bcrypt.checkpw(password.encode("UTF_8"), usr.passwordWeb.encode("UTF_8")):
            logger.warning("Password mismatch between form and database")
            # Decrement in SQL (retry = retry - 1) rather than writing back a value
            # computed from the row read above: concurrent failed attempts would
            # otherwise all store the same stale count and dodge the lockout.
            User.update(retry=User.retry - 1).where(User.pk == user_id).execute()
            return {"message": "connection.failed"}, HTTP_401_UNAUTHORIZED

        # A password login is never a device login: drop a device flag that an
        # earlier device login in this session may have left, otherwise
        # _populate_rights() would restrict this user's rights.
        session.pop("device", None)
        return _handle_login_success(usr, user_id)

    except Exception as e:
        logger.error("User: Failed to retrieve user")
        logger.error(str(e.args))
        return {"message": "connection.failed"}, HTTP_401_UNAUTHORIZED


def login_device(token=None, pincode=None):
    """Dispatch one step of the two-step device login.

    With a (truthy) ``token`` the session is opened via login_device_activate();
    otherwise ``pincode`` is exchanged for a one-time key via
    login_device_prepare().
    """
    if token:
        return login_device_activate(token)
    return login_device_prepare(pincode)


def login_device_activate(token):
    """Open a device session from a one-time key issued by login_device_prepare().

    The key must equal ``User.token`` and ``User.token_expires`` must still be in
    the future. The key is consumed on use (its expiry is moved to now), so it
    cannot be replayed for the rest of its validity window. The session is then
    flagged as a device session before the common login handling runs.

    Returns:
        The response of _handle_login_success() on success, otherwise
        ({"message": "connection.notexist"}, 401).
    """
    now = datetime.now()
    usr: User = User.get_or_none((User.token == token) & (User.token_expires > now))

    if usr is None:
        logger.warning("Device token not found")
        return {"message": "connection.notexist"}, HTTP_401_UNAUTHORIZED

    # Consume the key with a conditional update: only a request whose update still
    # matches the unexpired key may proceed, which also closes the race between two
    # concurrent uses of the same key.
    consumed = (
        User.update(token_expires=now)
        .where((User.pk == usr.pk) & (User.token == token) & (User.token_expires > now))
        .execute()
    )
    if not consumed:
        logger.warning("Device token not found")
        return {"message": "connection.notexist"}, HTTP_401_UNAUTHORIZED

    session["device"] = True
    return _handle_login_success(usr, usr.pk)


def login_device_prepare(pincode):
    """First step of the device login: exchange a pincode for a one-time key.

    The user is looked up by comparing ``crypte(pincode)`` with ``User.password``.
    If the account is enabled and not locked out, a random key valid for one
    minute is stored on the user row and returned; the device then passes it to
    login_device_activate().

    Returns:
        ({"message": ..., "key": <one-time key>}, 200) on success, otherwise
        ({"message": <i18n key>}, 401).
    """
    try:
        # TODO: Use some value that will first authenticate the machine
        # NOTE(review): the user is identified by the pincode alone; if two users share
        # a pincode, .get() returns an arbitrary one -- confirm pincodes are unique.
        crypted_pincode = crypte(pincode)
        usr: User = User.select().where((User.password == crypted_pincode)).get()
        logger.info(f"User {usr.username} line found {usr.pk}")

        if not usr.isEnabled:
            logger.warning(f"User: {usr.username} ({usr.pk}) is disable")
            return {"message": "connection.disabled"}, HTTP_401_UNAUTHORIZED

        if usr.retry <= 0:
            logger.warning(f"User: {usr.username} ({usr.pk}) Maximum de tentatives atteintes")
            return {"message": "connection.max_retry"}, HTTP_401_UNAUTHORIZED

        # Generate a key to use in the login process
        secretKey = secrets.token_urlsafe(32)
        # Write only the two token columns. A full usr.save() would write back every
        # field read above and could undo a concurrent change (e.g. the account being
        # disabled in the meantime).
        User.update(token=secretKey, token_expires=datetime.now() + timedelta(minutes=1)).where(
            User.pk == usr.pk
        ).execute()

        return {"message": f"Ready to connect with user : {usr.username}", "key": secretKey}, HTTP_200_OK

    except Exception as e:
        logger.error(str(e.args))
        return {"message": "connection.notexist"}, HTTP_401_UNAUTHORIZED


def logout_user():
    """Clear the session and record a "MWEB: Logout" history entry.

    The username is read from the session before it is cleared, with ``"-"``
    when no user is stored.

    Returns:
        ({}, 204)
    """
    username = session.get("username", "-")
    session.clear()

    entry = HistoriqueIdentification()
    entry.user = username
    entry.demande = "MWEB: Logout"
    entry.save()

    logger.info(f"User: logout {username}")
    return {}, HTTP_204_NO_CONTENT


def check_session_status():
    """Report how many seconds remain before the session times out.

    The remaining time is the app's permanent-session lifetime minus the time
    elapsed since ``session["last_action_time"]``.

    Returns:
        ({"remaining": <int seconds>}, 200) while time remains;
        ({}, 204) when the time is used up, or when the session holds no usable
        ``last_action_time`` (the session is then cleared);
        ({"message": ...}, 400) on misconfiguration or outside an app context.
    """
    try:
        # hasattr() on the current_app proxy raises RuntimeError outside an app
        # context (handled below); dir() on an unbound proxy silently returned [].
        if not hasattr(current_app, "permanent_session_lifetime"):
            logger.error("Session bad configuration")
            return {"message": "Session bad configuration"}, HTTP_400_BAD_REQUEST

        session_start = session.get("last_action_time")
        if not isinstance(session_start, datetime):
            # Missing (or not a datetime): no usable activity timestamp, drop the session.
            logger.error("Clear session")
            session.clear()
            return {}, HTTP_204_NO_CONTENT

        if session_start.tzinfo is None:
            # The timestamp is written as an aware UTC datetime, but depending on the
            # session serializer it may be read back naive; subtracting a naive value
            # from an aware one raises TypeError, so interpret it as UTC.
            session_start = session_start.replace(tzinfo=timezone.utc)

        session_lifetime = current_app.permanent_session_lifetime.total_seconds()
        elapsed = (datetime.now(timezone.utc) - session_start).total_seconds()
        remaining = max(0, session_lifetime - elapsed)

        if remaining <= 0:
            # NOTE(review): unlike the missing-timestamp case, an expired session is
            # not cleared here -- confirm the client logs out on this 204.
            return {}, HTTP_204_NO_CONTENT

        return {"remaining": int(remaining)}, HTTP_200_OK
    except RuntimeError:
        # Handle case when outside application context
        return {"message": "Application context not available"}, HTTP_400_BAD_REQUEST


def refresh_session_time():
    """Restart the inactivity timer by stamping ``session["last_action_time"]`` (UTC).

    Returns:
        ({}, 200), or ({"message": ...}, 400) when no request/application
        context is available.
    """
    try:
        now_utc = datetime.now(timezone.utc)
        session["last_action_time"] = now_utc
    except RuntimeError:
        return {"message": "Application context not available"}, HTTP_400_BAD_REQUEST
    return {}, HTTP_200_OK


def init_user_password(user_id):
    """Reset an enabled user's web password to a random temporary one.

    The new password is bcrypt-hashed into ``passwordWeb``, the account is
    flagged temporary (``isTemporary=1``), and the retry counter is reset to 5.
    The clear-text password is returned so it can be handed to the user.

    Returns:
        ({"password": <clear text>}, 200) on success;
        ({"message": "connection.failed"}, 400) if the user is unknown, disabled,
        or the update fails.
    """
    import string  # local import: only needed for the password alphabets

    # The former scheme ("D33nov@" + HHMMSS) was guessable from the reset time alone
    # (at most 86,400 candidates). Draw from the OS CSPRNG instead, while keeping the
    # character classes the old format contained (upper, lower, digit, symbol).
    symbols = "@#$%&*!?"
    chars = [
        secrets.choice(string.ascii_uppercase),
        secrets.choice(string.ascii_lowercase),
        secrets.choice(string.digits),
        secrets.choice(symbols),
    ]
    pool = string.ascii_letters + string.digits + symbols
    chars += [secrets.choice(pool) for _ in range(10)]
    # Shuffle so the guaranteed-class characters are not always at the front.
    secrets.SystemRandom().shuffle(chars)
    password = "".join(chars)

    try:
        usr = User.select().where((User.isEnabled == 1) & (User.pk == user_id)).get()

        User.update(
            passwordWeb=bcrypt.hashpw(password.encode("UTF_8"), bcrypt.gensalt()), isTemporary=1, retry=5
        ).where(User.pk == usr.pk).execute()

        return {"password": password}, HTTP_200_OK

    except Exception:
        logger.error("Fail to initialize user password")
        return {"message": "connection.failed"}, HTTP_400_BAD_REQUEST


def renew_password(password, confirm):
    """Replace the logged-in user's web password.

    ``password`` and ``confirm`` must be equal and non-empty. The user is taken
    from ``session["user_id"]``. The new password is bcrypt-hashed, the
    temporary flag is cleared and the retry counter is reset to 5.

    Returns:
        ({"route": "/reference.html"}, 202) on success;
        ({"message": "renewpassword.notequal"}, 400) if the two values differ;
        ({"message": "connection.failed"}, 401) when no user is in the session;
        ({"message": "connection.failed"}, 400) for an empty password, an
        unknown/disabled user, or a failed update.
    """
    if password != confirm:
        logger.error("Renew password not equal")
        return {"message": "renewpassword.notequal"}, HTTP_400_BAD_REQUEST

    user_id = session.get("user_id")
    if user_id is None:
        # Previously a KeyError (HTTP 500); a caller without a session is unauthorized.
        logger.error("Renew password failed: no user in session")
        return {"message": "connection.failed"}, HTTP_401_UNAUTHORIZED

    if not password:
        # bcrypt happily hashes an empty string; never store an empty password.
        logger.error("Renew password failed: empty password")
        return {"message": "connection.failed"}, HTTP_400_BAD_REQUEST

    try:
        usr = User.select().where((User.isEnabled == 1) & (User.pk == user_id)).get()
        User.update(
            passwordWeb=bcrypt.hashpw(password.encode("UTF_8"), bcrypt.gensalt()), isTemporary=0, retry=5
        ).where(User.pk == usr.pk).execute()

        return {"route": "/reference.html"}, HTTP_202_ACCEPTED
    except Exception:
        logger.error("Renew password failed")
        return {"message": "connection.failed"}, HTTP_400_BAD_REQUEST
