# pylint: disable=unbalanced-tuple-unpacking, too-many-locals
"""Tests for federated learning."""

import multiprocessing
import os
import subprocess
import tempfile
import time
from typing import List, cast

from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split

import xgboost as xgb
import xgboost.federated
from xgboost import testing as tm
from xgboost.training import TrainingCallback

SERVER_KEY = "server-key.pem"
SERVER_CERT = "server-cert.pem"
CLIENT_KEY = "client-key.pem"
CLIENT_CERT = "client-cert.pem"


def run_server(port: int, world_size: int, with_ssl: bool) -> None:
    """Run federated server for test."""
    if with_ssl:
        xgboost.federated.run_federated_server(
            world_size,
            port,
            server_key_path=SERVER_KEY,
            server_cert_path=SERVER_CERT,
            client_cert_path=CLIENT_CERT,
        )
    else:
        xgboost.federated.run_federated_server(world_size, port)


def run_worker(
    port: int, world_size: int, rank: int, with_ssl: bool, device: str
) -> None:
    """Run federated client worker for test."""
    communicator_env = {
        "dmlc_communicator": "federated",
        "federated_server_address": f"localhost:{port}",
        "federated_world_size": world_size,
        "federated_rank": rank,
    }
    if with_ssl:
        communicator_env["federated_server_cert_path"] = SERVER_CERT
        communicator_env["federated_client_key_path"] = CLIENT_KEY
        communicator_env["federated_client_cert_path"] = CLIENT_CERT

    cpu_count = os.cpu_count()
    assert cpu_count is not None
    n_threads = cpu_count // world_size

    # Always call this before using distributed module
    with xgb.collective.CommunicatorContext(**communicator_env):
        # Load file, file will not be sharded in federated mode.
        X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
        dtrain = xgb.DMatrix(X, y)
        X, y = load_svmlight_file(f"agaricus.txt-{rank}.test")
        dtest = xgb.DMatrix(X, y)

        # Specify parameters via map, definition are same as c++ version
        param = {
            "max_depth": 2,
            "eta": 1,
            "objective": "binary:logistic",
            "nthread": n_threads,
            "tree_method": "hist",
            "device": device,
        }

        # Specify validations set to watch performance
        watchlist = [(dtest, "eval"), (dtrain, "train")]
        num_round = 20

        # Run training, all the features in training API is available.
        results: TrainingCallback.EvalsLog = {}
        bst = xgb.train(
            param,
            dtrain,
            num_round,
            evals=watchlist,
            early_stopping_rounds=2,
            evals_result=results,
        )
        assert tm.non_increasing(cast(List[float], results["train"]["logloss"]))
        assert tm.non_increasing(cast(List[float], results["eval"]["logloss"]))

        # save the model, only ask process 0 to save the model.
        if xgb.collective.get_rank() == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                bst.save_model(os.path.join(tmpdir, "model.json"))
            xgb.collective.communicator_print("Finished training\n")


def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
    """Launcher for clients and the server."""
    port = 9091

    server = multiprocessing.Process(
        target=run_server, args=(port, world_size, with_ssl)
    )
    server.start()
    time.sleep(1)
    if not server.is_alive():
        raise ValueError("Error starting Federated Learning server")

    workers = []
    for rank in range(world_size):
        device = f"cuda:{rank}" if use_gpu else "cpu"
        worker = multiprocessing.Process(
            target=run_worker, args=(port, world_size, rank, with_ssl, device)
        )
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
    server.terminate()


def run_federated_learning(with_ssl: bool, use_gpu: bool, test_path: str) -> None:
    """Run federated learning tests."""
    n_workers = 2

    if with_ssl:
        command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost"  # pylint: disable=line-too-long
        server_key = command.format(part="server").split()
        subprocess.check_call(server_key)
        client_key = command.format(part="client").split()
        subprocess.check_call(client_key)

    train_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.train")
    test_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.test")

    X_train, y_train = load_svmlight_file(train_path)
    X_test, y_test = load_svmlight_file(test_path)

    X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
    X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
        X_test, y_test, test_size=0.5
    )

    dump_svmlight_file(X0, y0, "agaricus.txt-0.train")
    dump_svmlight_file(X0_valid, y0_valid, "agaricus.txt-0.test")

    dump_svmlight_file(X1, y1, "agaricus.txt-1.train")
    dump_svmlight_file(X1_valid, y1_valid, "agaricus.txt-1.test")

    run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)
