try:
    from xgboost.callback import TrainingCallback as XGBTrainingCallback
except ImportError:
    # Stub base class so this module can be imported without xgboost installed.
    class XGBTrainingCallback:
        pass


from IPython.display import display

from .metrics_plotter import MetricsPlotter


class XGBPlottingCallback(XGBTrainingCallback):
    """XGBoost callback that draws training metrics with CatBoost's plotting widget.
    """

    def __init__(self, total_iterations: int):
        self.plotter = None
        self.total_iterations = total_iterations

    def after_iteration(self, model, epoch, evals_log):
        """Log the latest metric values after every boosting iteration."""
        data_names = evals_log.keys()
        # When every sample is named like "validation_<i>" (as the sklearn API
        # names its eval sets) and more than one is passed, treat the first
        # one as the train sample.
        first_train = (
            all('valid' in data_name.lower() for data_name in data_names)
            and len(data_names) > 1
        )

        for data_name, metrics_info in evals_log.items():
            if "train" in data_name.lower() or first_train:
                train = True
                first_train = False
            elif "valid" in data_name.lower() or "test" in data_name.lower():
                train = False
            else:
                raise ValueError(
                    f"Unexpected sample name during evaluation: {data_name!r}"
                )

            # Keep only the most recent value of each metric.
            metrics = {name: values[-1] for name, values in metrics_info.items()}

            if self.plotter is None:
                # Lazily build the plotter on the first logged sample; the same
                # metric names are used for the train and test curves.
                names = list(metrics.keys())
                self.plotter = MetricsPlotter(names, names, self.total_iterations)
                display(self.plotter._widget)

            self.plotter.log(epoch, train, metrics)

        # False to indicate training should not stop.
        return False
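
# A minimal usage sketch (assumes xgboost >= 1.3, where the TrainingCallback
# API is available, and pre-built `dtrain`/`dvalid` DMatrix objects; the names
# below are illustrative):
#
#     import xgboost as xgb
#
#     num_rounds = 100
#     booster = xgb.train(
#         params,
#         dtrain,
#         num_boost_round=num_rounds,
#         evals=[(dtrain, "train"), (dvalid, "valid")],
#         callbacks=[XGBPlottingCallback(num_rounds)],
#     )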


def lgbm_plotting_callback():
    """LightGBM callback with metrics plotting widget from CatBoost
    """

    plotter = None

    def _init(env):
        # Discover metric names from the first evaluation results and build
        # the plotter.
        train_metrics = []
        test_metrics = []

        for item in env.evaluation_result_list:
            # Each item is (data_name, eval_name, result, is_higher_better).
            assert len(item) == 4, "Plotting was run in an unsupported mode"
            data_name, eval_name = item[:2]

            if "train" in data_name.lower():
                train_metrics.append(eval_name)
            elif "valid" in data_name.lower() or "test" in data_name.lower():
                test_metrics.append(eval_name)
            else:
                raise ValueError(
                    f"Unexpected sample name during evaluation: {data_name!r}"
                )

        nonlocal plotter
        plotter = MetricsPlotter(
            train_metrics, test_metrics, env.end_iteration - env.begin_iteration)

        display(plotter._widget)

    def _callback(env):
        if plotter is None:
            _init(env)

        metrics = {"train": {}, "test": {}}

        for item in env.evaluation_result_list:
            data_name, eval_name, result = item[:3]

            if "train" in data_name.lower():
                metrics["train"][eval_name] = result
            elif "valid" in data_name.lower() or "test" in data_name.lower():
                metrics["test"][eval_name] = result
            else:
                raise ValueError(
                    f"Unexpected sample name during evaluation: {data_name!r}"
                )

        plotter.log(
            env.iteration - env.begin_iteration,
            train=True,
            metrics=metrics["train"]
        )
        plotter.log(
            env.iteration - env.begin_iteration,
            train=False,
            metrics=metrics["test"]
        )

    # LightGBM runs callbacks sorted by their `order` attribute; 20 places
    # this after the built-in logging callback.
    _callback.order = 20
    return _callback
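
# A minimal usage sketch (assumes pre-built `train_set`/`valid_set` Dataset
# objects; the names below are illustrative):
#
#     import lightgbm as lgb
#
#     booster = lgb.train(
#         params,
#         train_set,
#         num_boost_round=100,
#         valid_sets=[train_set, valid_set],
#         valid_names=["train", "valid"],
#         callbacks=[lgbm_plotting_callback()],
#     )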
