"""Store information about the target model and data."""
from __future__ import annotations
import logging
import os
import pickle
import shutil
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import pandas as pd
import sklearn
import torch
import yaml
from sacroml.attacks.model_pytorch import PytorchModel
from sacroml.attacks.model_sklearn import SklearnModel
MODEL_REGISTRY: dict[str, Any] = {
"PytorchModel": PytorchModel,
"SklearnModel": SklearnModel,
}
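# Maps the "model_type" string recorded in target.yaml to the wrapper class
# used when reloading a saved model (see _load_model).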
DATA_ATTRIBUTES: list[str] = [
"X_train",
"y_train",
"X_test",
"y_test",
"X_train_orig",
"y_train_orig",
"X_test_orig",
"y_test_orig",
"proba_train",
"proba_test",
]
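# Each attribute listed here is pickled to "<name>.pkl" on save and its
# relative path is recorded in target.yaml as "<name>_path".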
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class Target: # pylint: disable=too-many-instance-attributes
"""Store information about the target model and data.
Attributes
----------
model : Any
Trained target model.
model_path : str
Path to a saved model.
model_module_path : str
Path to module containing model class.
model_name : str
Class name of model.
model_params : dict or None
Hyperparameters for instantiating the model.
train_module_path : str
Path to module containing training function.
train_params : dict or None
Hyperparameters for training the model.
dataset_name : str
The name of the dataset.
dataset_module_path : str
Path to module containing dataset loading function.
features : dict
Dictionary describing the dataset features.
X_train : np.ndarray or None
The (processed) training inputs.
y_train : np.ndarray or None
The (processed) training outputs.
X_test : np.ndarray or None
The (processed) testing inputs.
y_test : np.ndarray or None
The (processed) testing outputs.
X_train_orig : np.ndarray or None
The original (unprocessed) training inputs.
y_train_orig : np.ndarray or None
The original (unprocessed) training outputs.
X_test_orig : np.ndarray or None
The original (unprocessed) testing inputs.
y_test_orig : np.ndarray or None
The original (unprocessed) testing outputs.
proba_train : np.ndarray or None
The model predicted training probabilities.
proba_test : np.ndarray or None
The model predicted testing probabilities.
safemodel : list
Results of safemodel disclosure checking.
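
    Examples
    --------
    A minimal sketch, assuming a fitted scikit-learn classifier ``clf`` and
    numpy arrays already split into train and test sets (the names here are
    illustrative, not part of the API):

    >>> target = Target(
    ...     model=clf,
    ...     dataset_name="example",
    ...     X_train=X_train,
    ...     y_train=y_train,
    ...     X_test=X_test,
    ...     y_test=y_test,
    ... )
    >>> target.save("target_dir")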
"""
# Model attributes
model: Any = None
model_path: str = ""
model_module_path: str = ""
model_name: str = ""
model_params: dict | None = None
train_module_path: str = ""
train_params: dict | None = None
# Dataset attributes
dataset_name: str = ""
dataset_module_path: str = ""
features: dict = field(default_factory=dict)
# Data arrays
X_train: np.ndarray | None = None
y_train: np.ndarray | None = None
X_test: np.ndarray | None = None
y_test: np.ndarray | None = None
X_train_orig: np.ndarray | None = None
y_train_orig: np.ndarray | None = None
X_test_orig: np.ndarray | None = None
y_test_orig: np.ndarray | None = None
proba_train: np.ndarray | None = None
proba_test: np.ndarray | None = None
# Safemodel properties
safemodel: list = field(default_factory=list)
def __post_init__(self):
"""Initialise the model wrapper after dataclass creation."""
self.model = self._wrap_model(self.model)
def _wrap_model(self, model: Any) -> Any:
"""Wrap the model in a wrapper class."""
if model is None:
return None
if isinstance(model, sklearn.base.BaseEstimator):
return SklearnModel(
model=model,
model_path=self.model_path,
model_module_path=self.model_module_path,
model_name=self.model_name,
model_params=self.model_params,
train_module_path=self.train_module_path,
train_params=self.train_params,
)
if isinstance(model, torch.nn.Module):
return PytorchModel(
model=model,
model_path=self.model_path,
model_module_path=self.model_module_path,
model_name=self.model_name,
model_params=self.model_params,
train_module_path=self.train_module_path,
train_params=self.train_params,
)
if isinstance(model, (SklearnModel, PytorchModel)):
return model
raise ValueError(f"Unsupported model type: {type(model)}") # pragma: no cover
@property
def n_features(self) -> int:
"""Number of features."""
return len(self.features)
def add_feature(self, name: str, indices: list[int], encoding: str) -> None:
"""Add a feature description to the data dictionary."""
index: int = len(self.features)
self.features[index] = {
"name": name,
"indices": indices,
"encoding": encoding,
}
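    # Illustrative usage; the feature name, indices, and encoding values below
    # are assumptions for the sake of example, not part of the API:
    #
    #     target.add_feature(name="age", indices=[0], encoding="int")
    #     target.add_feature(name="colour", indices=[1, 2, 3], encoding="onehot")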
def add_safemodel_results(self, data: list) -> None:
"""Add safemodel disclosure checking results."""
self.safemodel = data
def has_model(self) -> bool:
"""Return whether the target has a loaded model."""
return self.model is not None and self.model.model is not None
def has_data(self) -> bool:
"""Return whether the target has all processed data."""
attrs: list[str] = ["X_train", "y_train", "X_test", "y_test"]
return all(getattr(self, attr) is not None for attr in attrs)
def has_raw_data(self) -> bool:
"""Return whether the target has all raw data."""
attrs: list[str] = [
"X_train_orig",
"y_train_orig",
"X_test_orig",
"y_test_orig",
]
return all(getattr(self, attr) is not None for attr in attrs)
def has_probas(self) -> bool:
"""Return whether the target has all probability data."""
return self.proba_train is not None and self.proba_test is not None
def get_generalisation_error(self) -> float:
"""Calculate model generalisation error."""
if not (self.has_model() and self.has_data()):
return np.nan
return self.model.get_generalisation_error(
self.X_train, self.y_train, self.X_test, self.y_test
)
def save(self, path: str = "target", ext: str = "pkl") -> None:
"""Save target to persistent storage."""
path = os.path.normpath(path)
os.makedirs(path, exist_ok=True)
# Create target dictionary
target = {
"dataset_name": self.dataset_name,
"dataset_module_path": self.dataset_module_path,
"features": self.features,
"generalisation_error": self.get_generalisation_error(),
"safemodel": self.safemodel,
}
# Save model
self._save_model(path, ext, target)
# Save dataset module
self._save_dataset_module(path, target)
# Save data arrays
for attr in DATA_ATTRIBUTES:
self._save_array(path, target, attr)
# Save YAML config
yaml_path = os.path.join(path, "target.yaml")
with open(yaml_path, "w", encoding="utf-8") as f:
yaml.dump(target, f, default_flow_style=False, sort_keys=False)
def load(self, path: str = "target") -> None:
"""Load target from persistent storage."""
yaml_path = os.path.join(path, "target.yaml")
with open(yaml_path, encoding="utf-8") as f:
target = yaml.safe_load(f)
# Load basic attributes
for attr in ["dataset_name", "safemodel"]:
if attr in target:
setattr(self, attr, target[attr])
# Load features (convert string keys to int)
if "features" in target:
self.features = {int(k): v for k, v in target["features"].items()}
# Load paths
if "dataset_module_path" in target:
self.dataset_module_path = os.path.join(path, target["dataset_module_path"])
# Load model and data
self._load_model(path, target)
for attr in DATA_ATTRIBUTES:
self._load_array(path, target, attr)
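    # Illustrative round trip; "target_dir", clf, and the arrays are
    # placeholders, not values defined in this module:
    #
    #     target = Target(model=clf, X_train=X_train, y_train=y_train,
    #                     X_test=X_test, y_test=y_test)
    #     target.save("target_dir")
    #     restored = Target()
    #     restored.load("target_dir")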
def _save_model(self, path: str, ext: str, target: dict) -> None:
"""Save model to disk."""
if self.model is None: # pragma: no cover
return
target.update(
{
"model_type": self.model.model_type,
"model_name": self.model.model_name,
"model_params": self.model.get_params(),
}
)
# Copy module files
if self.model_module_path:
shutil.copy2(self.model_module_path, os.path.join(path, "model.py"))
target["model_module_path"] = "model.py"
if getattr(self.model, "train_module_path", ""):
shutil.copy2(self.model.train_module_path, os.path.join(path, "train.py"))
target["train_module_path"] = "train.py"
target["train_params"] = self.model.train_params
# Save model
model_path = os.path.join(path, f"model.{ext}")
self.model.save(model_path)
target["model_path"] = f"model.{ext}"
def _save_dataset_module(self, path: str, target: dict) -> None:
"""Save dataset module."""
if self.dataset_module_path: # pragma: no cover
shutil.copy2(self.dataset_module_path, os.path.join(path, "dataset.py"))
target["dataset_module_path"] = "dataset.py"
def _save_array(self, path: str, target: dict, attr_name: str) -> None:
"""Save numpy array as pickle."""
arr = getattr(self, attr_name)
if arr is not None:
arr_path = os.path.join(path, f"{attr_name}.pkl")
with open(arr_path, "wb") as f:
pickle.dump(arr, f, protocol=pickle.HIGHEST_PROTOCOL)
target[f"{attr_name}_path"] = f"{attr_name}.pkl"
else:
target[f"{attr_name}_path"] = ""
def _load_model(self, path: str, target: dict) -> None:
"""Load model from disk."""
model_type = target.get("model_type", "")
if not model_type or model_type not in MODEL_REGISTRY: # pragma: no cover
logger.info("Cannot load model: %s", model_type)
return
model_class = MODEL_REGISTRY[model_type]
self.model = model_class.load(
model_path=os.path.join(path, target.get("model_path", "")),
model_module_path=os.path.join(path, target.get("model_module_path", "")),
model_name=target.get("model_name", ""),
model_params=target.get("model_params", {}),
train_module_path=os.path.join(path, target.get("train_module_path", "")),
train_params=target.get("train_params", {}),
)
logger.info("Loaded: %s : %s", model_type, target.get("model_name", ""))
def _load_array(self, path: str, target: dict, attr_name: str) -> None:
"""Load array from disk."""
path_key = f"{attr_name}_path"
if path_key in target and target[path_key]:
arr_path = os.path.join(path, target[path_key])
self.load_array(arr_path, attr_name)
def load_array(self, arr_path: str, attr_name: str) -> None:
"""Load array from pickle or CSV file."""
_, ext = os.path.splitext(arr_path)
if ext == ".pkl":
arr = self._load_pickle(arr_path, attr_name)
elif ext == ".csv": # pragma: no cover
arr = self._load_csv(arr_path, attr_name)
else: # pragma: no cover
raise ValueError(f"Unsupported file extension: {ext}")
setattr(self, attr_name, arr)
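    # Illustrative usage; the CSV path is hypothetical:
    #
    #     target.load_array("data/X_train.csv", "X_train")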
def _load_pickle(self, path: str, name: str) -> np.ndarray: # pragma: no cover
"""Load array from pickle file."""
try:
with open(path, "rb") as f:
arr = pickle.load(f)
if hasattr(arr, "shape"):
logger.info("%s shape: %s", name, arr.shape)
else:
logger.info("%s is a scalar value", name)
return arr
except FileNotFoundError as e:
raise FileNotFoundError(f"Pickle file not found: {path}") from e
except Exception as e:
raise ValueError(f"Error loading pickle file {path}: {e}") from e
def _load_csv(self, path: str, name: str) -> np.ndarray: # pragma: no cover
"""Load array from CSV file."""
try:
arr = pd.read_csv(path, header=None).values
logger.info("%s shape: %s", name, arr.shape)
return arr
except FileNotFoundError as e:
raise FileNotFoundError(f"CSV file not found: {path}") from e
except pd.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {path}") from e
except Exception as e:
raise ValueError(f"Error reading CSV file {path}: {e}") from e