import os
import json
from threading import Thread, Event
from traitlets import Unicode, Dict, default
from IPython.display import display
from ipywidgets import DOMWidget, Layout, widget_serialization


class MetricWidget(DOMWidget):
    """Base Jupyter widget that exposes a synced ``data`` dictionary.

    The private ``_view_*`` / ``_model_*`` traits identify the front-end
    view/model classes and the module (with its version range) that render
    this widget; all of them are tagged ``sync=True``.
    """

    # Front-end view/model class names.
    _view_name = Unicode('CatboostWidgetView').tag(sync=True)
    _model_name = Unicode('CatboostWidgetModel').tag(sync=True)
    # Front-end module that provides the view and the model.
    _view_module = Unicode('catboost-widget').tag(sync=True)
    _model_module = Unicode('catboost-widget').tag(sync=True)
    # Version range of the front-end module.
    _view_module_version = Unicode('^1.0.0').tag(sync=True)
    _model_module_version = Unicode('^1.0.0').tag(sync=True)

    # Payload synced with the front end, (de)serialized via ``widget_serialization``.
    data = Dict({}).tag(sync=True, **widget_serialization)

    @default('layout')
    def _default_layout(self):
        """Return the default layout: 500px tall, stretched along the cross axis."""
        return Layout(height='500px', align_self='stretch')


class MetricVisualizer(MetricWidget):
    """Widget that polls training directories and publishes their progress.

    Each directory is expected to contain a ``catboost_training.json`` file
    with ``meta`` and ``iterations`` entries; the parsed content is pushed to
    the synced ``data`` trait on every update.
    """

    def __init__(self, train_dirs, subdirs=False):
        """Create a visualizer.

        Args:
            train_dirs: A single directory path or an iterable of paths.
            subdirs: When true, watch the immediate subdirectories of every
                given path instead of the paths themselves.
        """
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(MetricVisualizer, self).__init__()
        if isinstance(train_dirs, str):
            train_dirs = [train_dirs]
        if subdirs:
            train_subdirs = []
            for train_dir in train_dirs:
                train_subdirs.extend(self._get_subdirectories(train_dir))
            train_dirs = train_subdirs
        # Copy so later changes to the caller's sequence do not affect us;
        # ``list()`` also accepts tuples and other non-sliceable iterables.
        self._train_dirs = list(train_dirs)

        # Display name per directory: its base name, or 'current' for the cwd.
        curdir = os.path.abspath(os.path.curdir)
        self._names = []
        for train_dir in self._train_dirs:
            abspath = os.path.abspath(train_dir)
            self._names.append(os.path.basename(abspath) if abspath != curdir else 'current')

        self._need_to_stop = Event()
        self._update_after_stop_signal = False
        # Background polling thread; created by ``_run_update``.
        self.thread = None

    def start(self):
        """Display the widget and refresh its data until asked to stop.

        Refreshes once immediately, then once per second until the stop event
        is set. If the stop was requested through ``_stop_update``, one final
        refresh is performed after the loop.
        """
        display(self)
        self._update_data()
        # ``wait`` returns True once the event is set, ending the loop.
        while not self._need_to_stop.wait(1.0):
            self._update_data()

        if self._update_after_stop_signal:
            self._update_data()

    def _run_update(self):
        """Run ``start`` in a new thread."""
        self.thread = Thread(target=self.start, args=())
        self.thread.start()

    def _stop_update(self):
        """Signal the polling loop to stop and wait for the thread to finish."""
        self._update_after_stop_signal = True
        self._need_to_stop.set()
        # Nothing to join if ``_run_update`` was never called.
        if self.thread is not None:
            self.thread.join()

    def _get_subdirectories(self, a_dir):
        """Return the paths of the immediate subdirectories of ``a_dir``."""
        candidates = (os.path.join(a_dir, name) for name in os.listdir(a_dir))
        return [candidate for candidate in candidates if os.path.isdir(candidate)]

    def _update_data(self):
        """Reload every directory and publish the result to ``self.data``.

        Sets the stop event when every loaded directory reports that its last
        recorded iteration is the final one.
        """
        data = {}
        all_completed = True

        for name, path in zip(self._names, self._train_dirs):
            content = self._update_data_from_dir(path)
            if not content:
                continue

            data[path] = {
                'path': path,
                'name': name,
                'content': content
            }

            passed_iterations = content['passed_iterations']
            total_iterations = content['total_iterations']
            # Compared with ``+ 1``, i.e. the recorded iteration number is
            # treated as zero-based; a total of 0 means nothing is known yet.
            all_completed &= (passed_iterations + 1 >= total_iterations and total_iterations != 0)

        if all_completed:
            self._need_to_stop.set()

        self.data = data

    def _update_data_from_dir(self, path):
        """Load training progress from ``path``.

        Returns:
            A dict with ``passed_iterations``, ``total_iterations`` and
            ``data`` (itself a dict with ``iterations`` and ``meta``). When the
            file is missing, unreadable or malformed, the counters are 0 and
            ``data`` holds an empty list and an empty dict.
        """
        meta = {}
        iterations = []
        passed_iterations = 0
        total_iterations = 0

        training_json = os.path.join(path, 'catboost_training.json')

        if os.path.isfile(training_json):
            try:
                with open(training_json, 'r') as json_data:
                    training_data = json.load(json_data)
                loaded_meta = training_data['meta']
                loaded_iterations = training_data['iterations']
                loaded_passed = loaded_iterations[-1]['iteration'] if loaded_iterations else 0
                loaded_total = loaded_meta['iteration_count'] if loaded_meta else 0
            except (ValueError, KeyError, TypeError, IndexError, IOError, OSError):
                # The file may be read while it is still being written, or may
                # vanish between the ``isfile`` check and ``open``. Treat that
                # as "no data yet" instead of killing the polling thread.
                pass
            else:
                # Commit only after the whole file parsed, so ``meta`` and
                # ``iterations`` always come from the same consistent read.
                meta = loaded_meta
                iterations = loaded_iterations
                passed_iterations = loaded_passed
                total_iterations = loaded_total

        return {
            'passed_iterations': passed_iterations,
            'total_iterations': total_iterations,
            'data': {
                'iterations': iterations,
                'meta': meta
            }
        }

    @staticmethod
    def _get_static_path(file_name):
        """Return the path of ``file_name`` located next to this module."""
        return os.path.join(os.path.dirname(__file__), file_name)
