import logging
from typing import Tuple, List, Set

from median.models import (
    ListeValide,
    ItemValide,
    Product,
    Peigne,
    ListeModel,
    ListeItemModel,
    Service,
    ReferencePerService,
    Gpao,
    Patient,
)
from median.constant import TypeServiListe, ReferenceDistributionType, TypeEtatGpao
from peewee import fn, JOIN, Value
from common.status import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR
from .stock_service import StockService

# Module-level logger; published under the fixed "median.completion" name (not __name__).
logger = logging.getLogger("median.completion")


class TrayService:
    """Service for handling tray operations.

    A tray is returned as a list with one slot per tray space. Each occupied slot is a dict holding the
    physical ``container`` code and the ``pillboxes`` placed in it, each pillbox carrying its ``items``.
    Unused slots are ``None``.
    """

    # Number of tray spaces used when no Peigne row matches the pillbox type code.
    DEFAULT_TRAY_SPACES = 7
    # A global tray is displayed as a single space.
    GLOBAL_TRAY_SPACES = 1
    # Placeholder for pillbox-level fields that have no physical meaning in global mode.
    GLOBAL_TERM = "GLOBAL"

    def __init__(self):
        self.stock_service = StockService()

    def get_tray_of_pillboxes(self, code: str, chrono: str = None, id_chargement: str = None) -> Tuple[dict, int]:
        """Main entry point for tray retrieval.

        If ``code`` matches a ``Service.code`` the tray is built in global mode, from the ward's global-list
        items. Otherwise ``code`` is treated as a pillbox code, and the nominative tray holding that pillbox
        is returned.

        Args:
            code: ward (service) code for global mode, or pillbox code for nominative mode.
            chrono: date compared with ``DATE(ddeb)`` in global mode when ``id_chargement`` is empty.
            id_chargement: loading id selecting the global lists; takes precedence over ``chrono``.

        Returns:
            ``({"tray": [...]}, HTTP_200_OK)`` on success. On any failure, returns
            ``({"error": msg}, HTTP_500_INTERNAL_SERVER_ERROR)``; the error is logged and never re-raised.
        """
        try:
            # Done inside the try so a database failure here also yields the 500 response.
            mode_is_global = Service.get_or_none(Service.code == code) is not None

            if mode_is_global:
                data, tray_code = self._get_global_tray(code, chrono, id_chargement)
            else:
                data, tray_code = self._get_nominative_tray(code)

            return self._build_tray_response(data, tray_code, mode_is_global)
        except Exception as err:
            # logger.exception keeps the traceback, which logger.error(err) dropped.
            logger.exception(err)
            return {"error": f"An error has occured ({str(err)})"}, HTTP_500_INTERNAL_SERVER_ERROR

    def _get_nominative_tray(self, pillbox_code: str) -> Tuple[List, Set]:
        """Build the nominative tray that holds ``pillbox_code``.

        The most recent ListeValide using this pillbox gives the tray (``id_plateau``). Every list on that
        tray is then grouped by physical container (``id_pilulier``), in ``pos_pilulier`` order.

        Returns:
            ``(data, tray_code)``. ``data`` has one slot per tray space; ``tray_code`` is the set of
            ``id_plateau`` values seen on the items, which the caller validates.

        Raises:
            Exception: the pillbox is unknown, was last used for a global dispensation, or the tray
            holds more containers than it has spaces.
        """
        # id_pilulier is a CharField, but some databases store the code as an int, where the equality
        # comparison alone is not reliable; the length is checked as well.
        most_recent: ListeValide = (
            ListeValide.select(ListeValide)
            .where(
                (ListeValide.id_pilulier == pillbox_code)
                & (fn.CHAR_LENGTH(ListeValide.id_pilulier) == len(pillbox_code))
            )
            .order_by(-ListeValide.pk)
            .first()
        )

        if most_recent is None:
            logger.info(f"Pillbox {pillbox_code} not found")
            raise Exception(f"Pillbox {pillbox_code} not found")

        # Check if the pillbox isn't a global type
        if most_recent.type_servi == TypeServiListe.GlobalePilulier.value:
            err_msg = "This pillbox has last been used for a global dispensation, please use the global mode instead."
            logger.error(err_msg)
            raise Exception(err_msg)

        # The first 3 characters of the pillbox code encode its type, which gives the tray capacity.
        pillbox_type_code = int(pillbox_code[:3])
        peigne: Peigne = Peigne.get_or_none(Peigne.type_peigne == pillbox_type_code)
        nb_spaces_in_tray = peigne.nb_pilulier if peigne is not None else self.DEFAULT_TRAY_SPACES
        logger.info(f"Showing {nb_spaces_in_tray} spaces using the type {pillbox_type_code}")

        pillboxes_on_tray = (
            ListeValide.select(ListeValide)
            .where(ListeValide.id_plateau == most_recent.id_plateau)
            .order_by(+ListeValide.pos_pilulier)
        )

        # Group lists by physical container in a single pass. Dicts keep insertion order, so containers
        # appear by first pos_pilulier, and each container's lists stay in pos_pilulier order.
        pillboxes_by_container = {}
        for liste in pillboxes_on_tray:
            pillboxes_by_container.setdefault(liste.id_pilulier, []).append(liste)

        if len(pillboxes_by_container) > nb_spaces_in_tray:
            logger.error(
                f"Incoherent tray spaces count. nbPillboxes = {len(pillboxes_by_container)} "
                f"> tray spaces = {nb_spaces_in_tray}"
            )
            raise Exception("Error while fetching the tray data")

        data = [None] * nb_spaces_in_tray
        tray_code = set()

        for index, (container, pillboxes) in enumerate(pillboxes_by_container.items()):
            container_pillboxes = []
            data[index] = {
                "container": container,
                "pillboxes": container_pillboxes,
            }

            for pillbox in pillboxes:
                pillbox_data = self._build_nominative_pillbox_data(pillbox)
                container_pillboxes.append(pillbox_data)

                for item in self._query_nominative_items(pillbox):
                    tray_code.add(item.id_plateau)
                    pillbox_data["items"].append(self._build_nominative_item_data(item, pillbox.service))

        return data, tray_code

    def _query_nominative_items(self, pillbox: ListeValide):
        """Return the ItemValide rows of ``pillbox``, one per item, with Product and Gpao info attached.

        Each row exposes ``gpao_ids`` and ``item_wms_ids``: comma-separated GROUP_CONCATs of the
        matching Gpao rows. They are ``None`` when no Gpao row joined.
        """
        return (
            ItemValide.select(
                ItemValide,
                Product,
                fn.GROUP_CONCAT(Gpao.pk.distinct()).alias("gpao_ids"),
                fn.GROUP_CONCAT(Gpao.item_wms.distinct()).alias("item_wms_ids"),
            )
            .join(Product, JOIN.LEFT_OUTER, on=(Product.reference == ItemValide.reference))
            .join_from(
                ItemValide,
                Gpao,
                JOIN.LEFT_OUTER,
                on=(Gpao.item_wms == ItemValide.item_wms)
                & (Gpao.liste == pillbox.liste)
                & (Gpao.item == ItemValide.item)
                & (Gpao.ref == ItemValide.reference)
                & (Gpao.fraction == ItemValide.fraction)
                & (Gpao.id_pilulier == ItemValide.id_pilulier),
                # NOTE: This might be overkill, but it ensures we really get the right gpao items
            )
            .group_by(ItemValide.pk)
            .where(ItemValide.liste_pk == pillbox.pk)
        )

    def _get_global_tray(self, code: str, chrono: str, id_chargement: str) -> Tuple[List, Set]:
        """Build the global tray of ward ``code``: a single space holding one pseudo-pillbox with every item.

        Returns:
            ``(data, tray_code)``, where ``tray_code`` collects the non-empty ``id_plateau`` values.

        Raises:
            Exception: the lists could not be fetched, or no global list matches the filters.
        """
        logger.info(f"Showing {self.GLOBAL_TRAY_SPACES} spaces for a global tray")

        try:
            # Execute the query here (list()) so database errors are wrapped too, not only build errors.
            global_rows = list(self._build_global_lists_query(code, chrono, id_chargement))
        except Exception as err:
            logger.error(f"Error fetching global lists: {str(err)}")
            raise Exception(f"Failed to fetch global lists: {str(err)}") from err

        if not global_rows:
            # Previously surfaced as a bare IndexError ("list index out of range").
            err_msg = f"No global list found for ward {code}"
            logger.info(err_msg)
            raise Exception(err_msg)

        pillbox_data = self._build_global_pillbox_data(global_rows, self.GLOBAL_TERM)

        data = [None] * self.GLOBAL_TRAY_SPACES
        data[0] = {
            "container": code,
            "pillboxes": [pillbox_data],
        }

        tray_code = set()
        for item in global_rows:
            if item.id_plateau:  # Only add if id_plateau exists
                tray_code.add(item.id_plateau)
            pillbox_data["items"].append(self._build_global_item_data(item, code))

        return data, tray_code

    def _build_global_lists_query(self, code: str, chrono: str, id_chargement: str):
        """Build the UNION of current (ListeItemModel) and validated (ItemValide) global items of ward ``code``.

        Both halves select the same columns in the same order. The ``source`` column tells them apart:
        'item' for ListeItemModel rows, 'item_valide' for ItemValide rows. Rows are filtered by
        ``id_chargement`` when given, otherwise by ``DATE(ddeb) == chrono``.
        """
        global_types = [TypeServiListe.GlobalePilulier.value, TypeServiListe.GlobaleBoite.value]

        # fetch from ListeItemModel
        item_query = (
            ListeItemModel.select(
                ListeModel.pk.alias("liste_pk"),
                ListeItemModel.pk.alias("item_pk"),
                ListeModel.id_chargement,
                ListeModel.service,
                ListeItemModel.reference,
                ListeItemModel.alveole_theo,
                ListeItemModel.fraction,
                ListeItemModel.id_plateau,
                ListeItemModel.item,
                ListeItemModel.qte_dem,
                ListeItemModel.qte_serv,
                ListeItemModel.dtprise,
                ListeItemModel.heure,
                Product,
                Value("item").alias("source"),
                ListeModel.date_creation.alias("date_creation"),
                Value(0).alias("solde"),  # Placeholder for solde
                Value(0).alias("wms"),  # Placeholder for wms
            )
            .join(ListeModel, JOIN.INNER, on=(ListeModel.liste == ListeItemModel.liste))
            .switch(ListeItemModel)
            .join(Product, JOIN.LEFT_OUTER, on=(Product.reference == ListeItemModel.reference))
            .where(ListeItemModel.type_servi << global_types)
        )

        # fetch from ItemValide
        item_valide_query = (
            ItemValide.select(
                ListeValide.pk.alias("liste_pk"),
                ItemValide.pk.alias("item_pk"),
                ListeValide.id_chargement,
                ListeValide.service,
                ItemValide.reference,
                ItemValide.alveole_theo,
                ItemValide.fraction,
                ItemValide.id_plateau,
                ItemValide.item,
                ItemValide.quantite_dem.alias("qte_dem"),
                ItemValide.quantite_serv.alias("qte_serv"),
                ItemValide.dtprise,
                ItemValide.heure,
                Product,
                Value("item_valide").alias("source"),
                fn.DATE(ListeValide.chrono).alias("date_creation"),
                fn.COALESCE(Gpao.solde, 0).alias("solde"),
                Gpao.item_wms,
            )
            .join(ListeValide, JOIN.INNER, on=(ListeValide.pk == ItemValide.liste_pk))
            .join_from(ItemValide, Product, JOIN.LEFT_OUTER, on=(Product.reference == ItemValide.reference))
            .join_from(
                ItemValide,
                Gpao,
                JOIN.LEFT_OUTER,
                on=((Gpao.liste == ListeValide.liste) & (Gpao.item == ItemValide.item) & (Gpao.solde == 1)),
            )
            .where(ItemValide.type_servi << global_types)
        )

        if id_chargement:
            item_query = item_query.where((ListeModel.service == code) & (ListeModel.id_chargement == id_chargement))
            item_valide_query = item_valide_query.where(
                (ListeValide.service == code) & (ListeValide.id_chargement == id_chargement)
            )
        else:
            item_query = item_query.where((ListeModel.service == code) & (fn.DATE(ListeModel.ddeb) == chrono))
            item_valide_query = item_valide_query.where(
                (ListeValide.service == code) & (fn.DATE(ListeValide.ddeb) == chrono)
            )

        # Combine both queries
        return item_query.union(item_valide_query)

    def _build_tray_response(self, data: List, tray_code: Set, mode_is_global: bool = False) -> Tuple[dict, int]:
        """Wrap ``data`` in the HTTP response tuple, rejecting nominative trays spread over several trays."""
        # In global mode, we don't care about the number of trays
        if not mode_is_global and len(tray_code) > 1:
            errormsg = f"Several trays have been found for the same pillbox: {tray_code}"
            logger.error(errormsg)
            return {"error": errormsg}, HTTP_500_INTERNAL_SERVER_ERROR

        return {"tray": data}, HTTP_200_OK

    def _build_nominative_pillbox_data(self, pillbox: ListeValide) -> dict:
        """Build the pillbox dict (with patient details and an empty ``items`` list) for nominative mode.

        Raises:
            Patient.DoesNotExist: no Patient row matches ``pillbox.num_ipp``.
        """
        patient: Patient = Patient.get(Patient.ipp == pillbox.num_ipp)

        return {
            "pk": pillbox.pk_pilulier,
            "plateau": pillbox.id_plateau,
            "chrono": pillbox.chrono.isoformat() if pillbox.chrono else None,
            "ward": pillbox.service,
            "patient": {
                "ipp": pillbox.num_ipp,
                "first_name": patient.prenom,
                "last_name": patient.nom,
                "maiden_name": patient.nom_jeune_fille,
                "birthdate": patient.date_naissance.isoformat() if patient.date_naissance else None,
                "stay": pillbox.num_sej,
            },
            "chargement": pillbox.id_chargement,
            "pos_pillbox": pillbox.pos_pilulier,
            "pillbox_code": pillbox.id_pilulier,
            "adr_carnet": pillbox.adr_carnet,
            "items": [],
        }

    def _build_global_pillbox_data(self, global_lists: List, global_term: str) -> dict:
        """Build the single pseudo-pillbox holding every item of a global tray.

        The date, ward and loading id are read from the first row of ``global_lists``, which must not be
        empty. ``global_term`` fills the fields that only make sense for a physical pillbox.
        """
        global_list = global_lists[0]
        # Columns selected from ListeModel are read through the joined ``listemodel`` attribute.
        date_creation = global_list.listemodel.date_creation
        return {
            "pk": global_term,
            "plateau": global_term,
            "chrono": date_creation.isoformat() if date_creation else None,
            "ward": global_list.listemodel.service,
            "patient": "",
            "chargement": global_list.listemodel.id_chargement,
            "pos_pillbox": 0,
            "pillbox_code": TypeServiListe.GlobalePilulier.value,
            "adr_carnet": global_term,
            "items": [],
        }

    def _get_ref_dest_data(self, product_pk, service_code, distribution_type) -> ReferencePerService:
        """Return the ReferencePerService row for a product, ward and distribution type, or ``None``.

        Raises:
            Service.DoesNotExist: ``service_code`` matches no Service.
        """
        return ReferencePerService.get_or_none(
            (ReferencePerService.ref_pk == product_pk)
            & (ReferencePerService.dest_pk == Service.get(Service.code == service_code).pk)
            & (ReferencePerService.distribution_type == distribution_type)
        )

    def _build_nominative_item_data(self, item: ItemValide, service) -> dict:
        """Build the item dict for nominative mode from a row of ``_query_nominative_items``.

        ``solde`` is forced to 1 when one of the item's Gpao rows (by item_wms and pillbox) is DONE with
        ``solde == 1``. Otherwise the item's own ``solde`` is used, or -1 when it has none.
        """
        product_pk = item.product.pk
        ref_dest_data = self._get_ref_dest_data(product_pk, service, ReferenceDistributionType.Nominative.value)

        # item_wms_ids is a comma-separated GROUP_CONCAT, or None when no Gpao row joined.
        if item.item_wms_ids and Gpao.get_or_none(
            (Gpao.item_wms << item.item_wms_ids.split(","))
            & (Gpao.id_pilulier == item.id_pilulier)
            & (Gpao.etat << [TypeEtatGpao.DONE.value])
            & (Gpao.solde == 1)
        ):
            item.solde = 1

        return {
            "pk": item.pk,
            "reference": item.reference,
            "designation": item.designation,
            "alveole": item.alveole_theo,
            "fraction": item.fraction,
            "id_plateau": item.id_plateau,
            "item_nb": item.item,
            "qty_dem": item.quantite_dem,
            "qty_serv": item.quantite_serv,
            "date_prise": item.dtprise.isoformat() if item.dtprise else None,
            "hour_prise": item.heure,
            "stock_info": self.stock_service.get_stock_info(item.reference, item.fraction),
            "distribution_type": ref_dest_data.distribution if ref_dest_data else None,
            "solde": getattr(item, "solde", -1),
            "ref_pk": product_pk,
            # The query aliases the concatenated Gpao pks as "gpao_ids"; reading "gpao_pks" always gave None.
            "test": getattr(item, "gpao_ids", None),
        }

    def _build_global_item_data(self, item: ListeItemModel, service) -> dict:
        """Build the item dict for global mode from a row of the global UNION query (current or validated)."""
        product_pk = item.product.pk
        ref_dest_data = self._get_ref_dest_data(product_pk, service, ReferenceDistributionType.Global.value)

        return {
            "pk": item.item_pk,
            "reference": item.reference,
            "designation": item.product.designation,
            "alveole": item.alveole_theo,
            "fraction": item.fraction,
            "id_plateau": item.id_plateau,
            "item_nb": item.item,
            "qty_dem": item.qte_dem,
            "qty_serv": item.qte_serv,
            "date_prise": item.dtprise.isoformat() if item.dtprise else None,
            "hour_prise": item.heure,
            "stock_info": self.stock_service.get_stock_info(item.reference, item.fraction),
            "distribution_type": ref_dest_data.distribution if ref_dest_data else None,
            "ref_pk": product_pk,
            "source": item.source,  # 'item' or 'item_valide'
            "solde": getattr(item, "solde", None),
            "item_wms": getattr(item, "wms", None),
        }

    def _build_global_finished_item_data(self, item: ItemValide, service) -> dict:
        """Build the item dict for global mode from a validated ItemValide (finished items)."""
        product_pk = item.product.pk
        ref_dest_data = self._get_ref_dest_data(product_pk, service, ReferenceDistributionType.Global.value)

        return {
            "pk": item.pk,
            "reference": item.reference,
            "designation": item.product.designation,
            "alveole": item.alveole_theo,
            "fraction": item.fraction,
            "id_plateau": item.id_plateau,
            "item_nb": item.item,
            "qty_dem": item.quantite_dem,
            "qty_serv": item.quantite_serv,
            "date_prise": item.dtprise.isoformat() if item.dtprise else None,
            "hour_prise": item.heure,
            "stock_info": self.stock_service.get_stock_info(item.reference, item.fraction),
            "distribution_type": ref_dest_data.distribution if ref_dest_data else None,
            "ref_pk": product_pk,
        }
