"""Run a worst case attack based upon predictive probabilities."""
from __future__ import annotations
import logging
from collections.abc import Iterable
import numpy as np
from fpdf import FPDF
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sacroml import metrics
from sacroml.attacks import report
from sacroml.attacks.attack import Attack
from sacroml.attacks.target import Target
from sacroml.attacks.utils import get_class_by_name
# NOTE(review): basicConfig() configures the root logger as an import-time
# side effect; applications embedding this module may prefer to configure
# logging themselves.
logging.basicConfig(level=logging.INFO)
# Logger for this module.
logger = logging.getLogger(__name__)
# Significance level constant; same value as the default of the `p_thresh`
# constructor argument of WorstCaseAttack.
P_THRESH = 0.05
[docs]
class WorstCaseAttack(Attack):
    """Worst case attack.

    Trains an attack model on the target model's predictive probabilities,
    labelling rows that came from the training set as 1 and rows that came
    from the test set as 0, and records metrics on how well the two can be
    told apart. The same procedure can be repeated on randomly generated
    ("dummy") probabilities to provide a baseline.
    """
[docs]
    def __init__(
        self,
        output_dir: str = "outputs",
        write_report: bool = True,
        n_reps: int = 10,
        reproduce_split: int | Iterable[int] | None = 5,
        p_thresh: float = 0.05,
        n_dummy_reps: int = 1,
        train_beta: int = 1,
        test_beta: int = 1,
        test_prop: float = 0.2,
        include_model_correct_feature: bool = False,
        sort_probs: bool = True,
        attack_model: str = "sklearn.ensemble.RandomForestClassifier",
        attack_model_params: dict | None = None,
    ) -> None:
        """Construct an object to execute a worst case attack.

        Parameters
        ----------
        output_dir : str
            Name of the directory where outputs are stored.
        write_report : bool
            Whether to generate a JSON and PDF report.
        n_reps : int
            Number of attacks to run -- in each iteration an attack model
            is trained on a different subset of the data.
        reproduce_split : int or Iterable[int] or None
            Variable that controls the reproducibility of the data split.
            It can be an integer, which is expanded into one seed per
            repetition, or a list of integers of length `n_reps`; a list of a
            different length is truncated or padded to `n_reps` entries after
            duplicates are removed.
            Default : 5.
        p_thresh : float
            Significance threshold applied to the p-values derived from the
            attack metrics (`auc_p_vals` and `pdif_vals`) when counting how
            many repetitions were significant.
        n_dummy_reps : int
            Number of baseline (dummy) experiments to do.
        train_beta : int
            Value of b for beta distribution used to sample the in-sample
            (training) probabilities.
        test_beta : int
            Value of b for beta distribution used to sample the out-of-sample
            (test) probabilities.
        test_prop : float
            Proportion of data to use as a test set for the attack model.
        include_model_correct_feature : bool
            Inclusion of additional feature to hold whether or not the target model
            made a correct prediction for each example.
        sort_probs : bool
            Whether to sort combined preds (from training and test)
            to have highest probabilities in the first column.
        attack_model : str
            Class name of the attack model.
        attack_model_params : dict or None
            Dictionary of hyperparameters for the `attack_model`
            such as `min_samples_split`, `min_samples_leaf`, etc.
        """
        super().__init__(output_dir=output_dir, write_report=write_report)
        self.n_reps: int = n_reps
        self.reproduce_split: int | Iterable[int] | None = reproduce_split
        self.p_thresh: float = p_thresh
        self.n_dummy_reps: int = n_dummy_reps
        self.train_beta: int = train_beta
        self.test_beta: int = test_beta
        self.test_prop: float = test_prop
        self.include_model_correct_feature: bool = include_model_correct_feature
        self.sort_probs: bool = sort_probs
        self.attack_model: str = attack_model
        self.attack_model_params: dict | None = attack_model_params
        # One entry per dummy (baseline) experiment; filled by attack_from_preds.
        self.dummy_attack_metrics: list = []
    def __str__(self) -> str:
        """Return the human-readable name of this attack."""
        return "WorstCase attack"
[docs]
    @classmethod
    def attackable(cls, target: Target) -> bool:  # pragma: no cover
        """Return whether a target can be assessed with WorstCaseAttack."""
        required_methods = ["predict_proba", "predict"]
        if (
            target.has_model()
            and target.has_data()
            and all(hasattr(target.model, method) for method in required_methods)
        ) or target.has_probas():
            return True
        logger.info("WARNING: WorstCaseAttack requires more Target details.")
        return False 
    def _attack(self, target: Target) -> dict:
        """Run the worst case attack against `target` and build its report.

        Parameters
        ----------
        target : attacks.target.Target
            Target to attack. Its model and data are used to compute the
            probabilities when both are present; otherwise the probabilities
            stored on the target are used.

        Returns
        -------
        dict
            Attack report.
        """
        correct_train = None
        correct_test = None
        if not (target.has_model() and target.has_data()):
            # Cannot query the model: rely on the supplied probabilities.
            proba_train = target.proba_train
            proba_test = target.proba_test
        else:  # pragma: no cover
            # Query the target model for its probabilities.
            proba_train = target.model.predict_proba(target.X_train)
            proba_test = target.model.predict_proba(target.X_test)
            if self.include_model_correct_feature:
                # 1 where the target model predicted the true label, else 0.
                correct_train = 1 * (
                    target.y_train == target.model.predict(target.X_train)
                )
                correct_test = 1 * (
                    target.y_test == target.model.predict(target.X_test)
                )
        self.attack_from_preds(
            proba_train,
            proba_test,
            train_correct=correct_train,
            test_correct=correct_test,
        )
        # Build the report, persist it, and hand it back to the caller.
        attack_report = self._make_report(target)
        self._write_report(attack_report)
        return attack_report
    def _make_report(self, target: Target) -> dict:
        """Create the attack report, adding the per-instance dummy metrics.

        Extends the parent class report with a
        `dummy_attack_experiments_logger` entry.
        """
        attack_report = super()._make_report(target)
        dummy_log = self._get_dummy_attack_metrics_experiments_instances()
        attack_report["dummy_attack_experiments_logger"] = dummy_log
        return attack_report
[docs]
    def attack_from_preds(
        self,
        proba_train: np.ndarray,
        proba_test: np.ndarray,
        train_correct: np.ndarray | None = None,
        test_correct: np.ndarray | None = None,
    ) -> None:
        """Run the main and the dummy attacks from predicted probabilities.

        The main attack metrics are stored in `self.attack_metrics`; the
        metrics of each dummy (baseline) experiment are appended to
        `self.dummy_attack_metrics`, which is reset on every call.

        Parameters
        ----------
        proba_train : np.ndarray
            Array of train predictions. One row per example, one column per class.
        proba_test : np.ndarray
            Array of test predictions. One row per example, one column per class.
        train_correct : np.ndarray or None
            Optional per-example indicator passed through to the main attack
            repetitions for the training rows.
        test_correct : np.ndarray or None
            Optional per-example indicator passed through to the main attack
            repetitions for the test rows.
        """
        logger.info("Running main attack repetitions")
        main_results = self.run_attack_reps(
            proba_train,
            proba_test,
            train_correct=train_correct,
            test_correct=test_correct,
        )
        self.attack_metrics = main_results["mia_metrics"]

        self.dummy_attack_metrics = []
        if self.n_dummy_reps > 0:
            logger.info("Running dummy attack reps")
            # Dummy data mirrors the sizes of the real train/test arrays.
            n_in = len(proba_train)
            n_out = len(proba_test)
            for _ in range(self.n_dummy_reps):
                dummy_train, dummy_test = self.generate_arrays(
                    n_in, n_out, self.train_beta, self.test_beta
                )
                dummy_results = self.run_attack_reps(dummy_train, dummy_test)
                self.dummy_attack_metrics.append(dummy_results["mia_metrics"])
        logger.info("Finished running attacks")
    def _prepare_attack_data(
        self,
        proba_train: np.ndarray,
        proba_test: np.ndarray,
        train_correct: np.ndarray | None = None,
        test_correct: np.ndarray | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Build the feature matrix and membership labels for the attack model.

        The train and test predictions are stacked into one array, each row
        optionally sorted in descending order so the highest probability is
        in the first column. When the "model correct" feature is enabled and
        `train_correct` is supplied, the correctness indicators are appended
        as a final column. Labels are one for training rows and zero for
        test rows.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Features `mi_x` and labels `mi_y`.
        """
        in_feats, out_feats = proba_train, proba_test
        if self.sort_probs:
            logger.info("Sorting probabilities to leave highest value in first column")
            # Negate, sort ascending, negate back: a descending sort per row.
            in_feats = -np.sort(-in_feats, axis=1)
            out_feats = -np.sort(-out_feats, axis=1)

        logger.info("Creating MIA data")
        add_correct_column = (
            self.include_model_correct_feature and train_correct is not None
        )
        if add_correct_column:
            in_feats = np.hstack((in_feats, train_correct[:, np.newaxis]))
            out_feats = np.hstack((out_feats, test_correct[:, np.newaxis]))

        features = np.vstack((in_feats, out_feats))
        member_labels = np.ones(len(in_feats))
        non_member_labels = np.zeros(len(out_feats))
        labels = np.hstack((member_labels, non_member_labels))
        return (features, labels)
    def _get_attack_model(self):
        """Instantiate and return the configured attack model.

        The class is looked up from `self.attack_model`. If no
        hyperparameters were supplied and the model is the random forest
        classifier, a custom set of defaults is used; any other model without
        hyperparameters is built with its own defaults.
        """
        model_cls = get_class_by_name(self.attack_model)
        hyperparams = self.attack_model_params
        if hyperparams is not None:
            return model_cls(**hyperparams)
        if self.attack_model == "sklearn.ensemble.RandomForestClassifier":
            # custom default parameters for the RF attack model
            return model_cls(
                **{
                    "min_samples_split": 20,
                    "min_samples_leaf": 10,
                    "max_depth": 5,
                }
            )
        return model_cls()
    def _get_reproducible_split(self) -> list:
        """Return a list of splits."""
        split = self.reproduce_split
        n_reps = self.n_reps
        if isinstance(split, int):
            split = [split] + [x**2 for x in range(split, split + n_reps - 1)]
        else:
            # remove potential duplicates
            split = list(dict.fromkeys(split))
            if len(split) == n_reps:
                pass
            elif len(split) > n_reps:
                print("split", split, "nreps", n_reps)
                split = list(split)[0:n_reps]
                print(
                    "WARNING: the length of the parameter 'reproduce_split' "
                    "is longer than n_reps. Values have been removed."
                )
            else:
                # assign values to match length of n_reps
                split += [split[-1] * x for x in range(2, (n_reps - len(split) + 2))]
                print(
                    "WARNING: the length of the parameter 'reproduce_split' "
                    "is shorter than n_reps. Values have been added."
                )
            print("reproduce split now", split)
        return split
[docs]
    def run_attack_reps(
        self,
        proba_train: np.ndarray,
        proba_test: np.ndarray,
        train_correct: np.ndarray | None = None,
        test_correct: np.ndarray | None = None,
    ) -> dict:
        """Run actual attack reps from train and test predictions.

        For each repetition the combined data is split into attack-train and
        attack-test parts (stratified on membership), a fresh attack model is
        fitted, and metrics are computed on the attack-test part.

        Parameters
        ----------
        proba_train : np.ndarray
            Predictions from the model on training (in-sample) data.
        proba_test : np.ndarray
            Predictions from the model on testing (out-of-sample) data.
        train_correct : np.ndarray or None
            Per-example indicator for the training rows, used as an extra
            feature when `include_model_correct_feature` is set.
        test_correct : np.ndarray or None
            Per-example indicator for the test rows, used as an extra
            feature when `include_model_correct_feature` is set.

        Returns
        -------
        dict
            Dictionary of mia_metrics (a list of metric across repetitions).
        """
        mi_x, mi_y = self._prepare_attack_data(
            proba_train, proba_test, train_correct, test_correct
        )
        # The Yeom metrics read the extra "model correct" column, which only
        # exists when the feature is enabled and values were supplied.
        with_yeom = self.include_model_correct_feature and train_correct is not None
        mia_metrics = []
        split = self._get_reproducible_split()
        for rep in range(self.n_reps):
            seed = split[rep]
            # %s keeps the message valid for any seed accepted as
            # `random_state` (an int or None); ints render exactly as before.
            logger.info("Rep %d of %d split %s", rep + 1, self.n_reps, seed)
            mi_train_x, mi_test_x, mi_train_y, mi_test_y = train_test_split(
                mi_x,
                mi_y,
                test_size=self.test_prop,
                stratify=mi_y,
                random_state=seed,
                shuffle=True,
            )
            attack_classifier = self._get_attack_model()
            attack_classifier.fit(mi_train_x, mi_train_y)
            y_pred_proba = attack_classifier.predict_proba(mi_test_x)
            rep_metrics = metrics.get_metrics(y_pred_proba, mi_test_y)
            mia_metrics.append(rep_metrics)
            if with_yeom:
                # Compute the Yeom TPR and FPR: the last feature column (the
                # "model correct" indicator) is used directly as the
                # membership prediction.
                yeom_preds = mi_test_x[:, -1]
                tn, fp, fn, tp = confusion_matrix(mi_test_y, yeom_preds).ravel()
                rep_metrics["yeom_tpr"] = tp / (tp + fn)
                rep_metrics["yeom_fpr"] = fp / (fp + tn)
                rep_metrics["yeom_advantage"] = (
                    rep_metrics["yeom_tpr"] - rep_metrics["yeom_fpr"]
                )
        logger.info("Finished simulating attacks")
        return {"mia_metrics": mia_metrics}
    def _get_global_metrics(self, attack_metrics: list) -> dict:
        """Summarise metrics from a metric list.
        Parameters
        ----------
        attack_metrics : List
            list of attack metrics dictionaries
        Returns
        -------
        global_metrics : Dict
            Dictionary of summary metrics
        """
        global_metrics = {}
        if attack_metrics is not None and len(attack_metrics) != 0:
            auc_p_vals = [
                metrics.auc_p_val(
                    m["AUC"], m["n_pos_test_examples"], m["n_neg_test_examples"]
                )[0]
                for m in attack_metrics
            ]
            m = attack_metrics[0]
            _, auc_std = metrics.auc_p_val(
                0.5, m["n_pos_test_examples"], m["n_neg_test_examples"]
            )
            global_metrics["null_auc_3sd_range"] = (
                f"{0.5 - 3 * auc_std:.4f} -> {0.5 + 3 * auc_std:.4f}"
            )
            global_metrics["n_sig_auc_p_vals"] = self._get_n_significant(
                auc_p_vals, self.p_thresh
            )
            global_metrics["n_sig_auc_p_vals_corrected"] = self._get_n_significant(
                auc_p_vals, self.p_thresh, bh_fdr_correction=True
            )
            pdif_vals = [np.exp(-m["PDIF01"]) for m in attack_metrics]
            global_metrics["n_sig_pdif_vals"] = self._get_n_significant(
                pdif_vals, self.p_thresh
            )
            global_metrics["n_sig_pdif_vals_corrected"] = self._get_n_significant(
                pdif_vals, self.p_thresh, bh_fdr_correction=True
            )
        return global_metrics
    def _get_n_significant(
        self, p_val_list: list[float], p_thresh: float, bh_fdr_correction: bool = False
    ) -> int:
        """Return number of p-values significant at `p_thresh`.
        Can perform multiple testing correction.
        """
        if not bh_fdr_correction:
            return sum(1 for p in p_val_list if p <= p_thresh)
        p_val_list = np.asarray(sorted(p_val_list))
        n_vals = len(p_val_list)
        hoch_vals = np.array([(k / n_vals) * P_THRESH for k in range(1, n_vals + 1)])
        bh_sig_list = p_val_list <= hoch_vals
        return np.where(bh_sig_list)[0].max() + 1 if any(bh_sig_list) else 0
    def _generate_array(self, n_rows: int, beta: float) -> np.ndarray:
        """Generate array of predictions, used when doing baseline experiments.
        Parameters
        ----------
        n_rows : int
            The number of rows worth of data to generate.
        beta : float
            The beta parameter for sampling probabilities.
        Returns
        -------
        preds : np.ndarray
            Array of predictions. Two columns, `n_rows` rows.
        """
        preds = np.zeros((n_rows, 2), float)
        for row_idx in range(n_rows):
            train_class = np.random.choice(2)
            train_prob = np.random.beta(1, beta)
            preds[row_idx, train_class] = train_prob
            preds[row_idx, 1 - train_class] = 1 - train_prob
        return preds
[docs]
    def generate_arrays(
        self,
        n_rows_in: int,
        n_rows_out: int,
        train_beta: float = 2,
        test_beta: float = 2,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Generate train and test prediction arrays, used when computing baseline.
        Parameters
        ----------
        n_rows_in : int
            Number of rows of in-sample (training) probabilities.
        n_rows_out : int
            Number of rows of out-of-sample (testing) probabilities.
        train_beta : float
            Beta value for generating train probabilities.
        test_beta : float:
            Beta value for generating test probabilities.
        Returns
        -------
        proba_train : np.ndarray
            Array of train predictions (n_rows x 2 columns).
        proba_test : np.ndarray
            Array of test predictions (n_rows x 2 columns).
        """
        proba_train = self._generate_array(n_rows_in, train_beta)
        proba_test = self._generate_array(n_rows_out, test_beta)
        return proba_train, proba_test 
    def _construct_metadata(self) -> None:
        """Add the summary metrics to the metadata once attacks have run.

        Stores the summary of the main attack under `global_metrics` and the
        summary of all dummy attack repetitions under
        `baseline_global_metrics`.
        """
        super()._construct_metadata()
        main_summary = self._get_global_metrics(self.attack_metrics)
        self.metadata["global_metrics"] = main_summary
        flat_dummy_metrics = self._unpack_dummy_attack_metrics_experiments_instances()
        baseline_summary = self._get_global_metrics(flat_dummy_metrics)
        self.metadata["baseline_global_metrics"] = baseline_summary
    def _unpack_dummy_attack_metrics_experiments_instances(self) -> list:
        """Construct the metadata object after attacks."""
        dummy_attack_metrics_instances = []
        for exp_rep, _ in enumerate(self.dummy_attack_metrics):
            temp_dummy_attack_metrics = self.dummy_attack_metrics[exp_rep]
            dummy_attack_metrics_instances += temp_dummy_attack_metrics
        return dummy_attack_metrics_instances
    def _get_attack_metrics_instances(self) -> dict:
        """Construct the metadata object after attacks."""
        attack_metrics_experiment = {}
        attack_metrics_instances = {}
        for rep, _ in enumerate(self.attack_metrics):
            attack_metrics_instances["instance_" + str(rep)] = self.attack_metrics[rep]
        attack_metrics_experiment["attack_instance_logger"] = attack_metrics_instances
        return attack_metrics_experiment
    def _get_dummy_attack_metrics_experiments_instances(self) -> dict:
        """Construct the metadata object after attacks."""
        dummy_attack_metrics_experiments = {}
        for exp_rep, _ in enumerate(self.dummy_attack_metrics):
            temp_dummy_attack_metrics = self.dummy_attack_metrics[exp_rep]
            dummy_attack_metric_instances = {}
            for rep, _ in enumerate(temp_dummy_attack_metrics):
                dummy_attack_metric_instances["instance_" + str(rep)] = (
                    temp_dummy_attack_metrics[rep]
                )
            temp = {}
            temp["attack_instance_logger"] = dummy_attack_metric_instances
            dummy_attack_metrics_experiments[
                "dummy_attack_metrics_experiment_" + str(exp_rep)
            ] = temp
        return dummy_attack_metrics_experiments
    def _make_pdf(self, output: dict) -> FPDF:
        """Create PDF report.

        Delegates to `report.create_mia_report`, passing the attack report
        dictionary `output`.
        """
        return report.create_mia_report(output)