PyTorch Examples

This section demonstrates how to use SACRO-ML with PyTorch models for privacy assessment.

Simple PyTorch Example

A basic example showing how to train a simple PyTorch model and run privacy attacks.

Training the Model:

Simple PyTorch Model Training (from examples/pytorch/simple/train_pytorch.py)
 1"""Train a classifier on synthetic data using sacroml Target and Dataset classes."""
 2
 3import logging
 4
 5import torch
 6from dataset import Synthetic
 7from model import OverfitNet
 8from train import test, train
 9
10from sacroml.attacks.target import Target
11
12target_dir = "target_pytorch"
13random_state = 2
14
15if __name__ == "__main__":
16    torch.manual_seed(random_state)
17    if torch.cuda.is_available():
18        torch.cuda.manual_seed_all(random_state)
19        logging.info(torch.cuda.get_device_name(torch.cuda.current_device()))
20    else:
21        logging.info("Found no NVIDIA driver on your system")
22
23    #############################################################################
24    # Dataset loading and model training
25    #############################################################################
26
27    logging.info("Loading dataset")
28
29    # Access dataset
30    data_handler = Synthetic()
31
32    # Get the (preprocessed) dataset
33    dataset = data_handler.get_dataset()
34
35    # Create data splits
36    indices_train, indices_test = data_handler.get_train_test_indices()
37
38    # Get dataloaders
39    train_loader = data_handler.get_dataloader(dataset, indices_train, shuffle=True)
40    test_loader = data_handler.get_dataloader(dataset, indices_test, shuffle=False)
41
42    logging.info("Defining the model")
43
44    model_params = {
45        "x_dim": 4,
46        "y_dim": 4,
47        "n_units": 1000,
48    }
49    train_params = {
50        "epochs": 1000,
51        "learning_rate": 0.001,
52        "momentum": 0.9,
53    }
54    model = OverfitNet(**model_params)
55
56    logging.info("Training the model")
57    train(model, train_loader, **train_params)
58
59    logging.info("Testing the model")
60    test(model, test_loader)
61
62    #############################################################################
63    # Below shows the use of the Target class to help generate the target_dir/
64    # If you have already saved your model, you can use the CLI target generator.
65    #############################################################################
66
67    logging.info("Wrapping the model and data in a Target object")
68    target = Target(
69        model=model,
70        model_module_path="model.py",
71        model_params=model_params,  # Must match all required in model constructor
72        train_module_path="train.py",
73        train_params=train_params,  # Must match all required in the train function
74        dataset_module_path="dataset.py",
75        dataset_name="Synthetic",  # Must match the class name in dataset module
76        indices_train=indices_train,
77        indices_test=indices_test,
78    )
79
80    logging.info("Writing Target object to directory: '%s'", target_dir)
81    target.save(target_dir)

Note: the training script above is included from examples/pytorch/simple/train_pytorch.py. It imports `OverfitNet` from model.py (shown below), the `Synthetic` dataset handler from dataset.py, and the `train` and `test` functions from train.py. dataset.py and train.py live in the same directory but are not reproduced on this page.

Model Definition:

Simple PyTorch Model Architecture (from examples/pytorch/simple/model.py)
 1"""An example Pytorch classifier."""
 2
 3import torch
 4
 5
 6class OverfitNet(torch.nn.Module):
 7    """An example Pytorch classification model."""
 8
 9    def __init__(self, x_dim: int, y_dim: int, n_units: int) -> None:
10        """Construct a simple Pytorch model."""
11        super().__init__()
12        self.layers = torch.nn.Sequential(
13            torch.nn.Linear(x_dim, n_units),
14            torch.nn.ReLU(),
15            torch.nn.Linear(n_units, y_dim),
16        )
17
18    def forward(self, x: torch.Tensor) -> torch.Tensor:
19        """Forward propagate input."""
20        return self.layers(x)

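Note: this model code is included from examples/pytorch/simple/model.py. It defines the `OverfitNet` class that the training script instantiates. The saved Target records `model_module_path="model.py"` together with `model_params`, so the keys of `model_params` must match the arguments required by the `OverfitNet` constructor.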

Running Privacy Attacks:

Privacy Attacks on Simple PyTorch Model (from examples/pytorch/simple/attack_pytorch.py)
 1"""Example of how to run attacks on a model saved with the Target wrapper."""
 2
 3import logging
 4
 5from sacroml.attacks.likelihood_attack import LIRAAttack
 6from sacroml.attacks.target import Target
 7from sacroml.attacks.worst_case_attack import WorstCaseAttack
 8
 9output_dir = "output_pytorch"
10target_dir = "target_pytorch"
11
12
13if __name__ == "__main__":
14    logging.info("Loading Target object from '%s'", target_dir)
15
16    target = Target()
17    target.load(target_dir)
18
19    logging.info("Running attacks...")
20
21    attack = WorstCaseAttack(n_reps=10, output_dir=output_dir)
22    attack.attack(target)
23
24    attack = LIRAAttack(n_shadow_models=100, output_dir=output_dir)
25    attack.attack(target)

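Note: the attack script is included from examples/pytorch/simple/attack_pytorch.py. It does not import the model or training code directly. Instead, it loads the Target that the training script saved to `target_pytorch`, so run train_pytorch.py first. That Target records the module paths of model.py, train.py and dataset.py, which the attacks may rely on.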

CIFAR Dataset Example

A more advanced example that trains a convolutional neural network on the CIFAR-10 dataset.

Training the Model:

 1"""Train a PyTorch classifier on CIFAR10 using sacroml Target and Dataset classes."""
 2
 3import logging
 4
 5import torch
 6from dataset import Cifar10
 7from model import Net
 8from train import test, train
 9
10from sacroml.attacks.target import Target
11
12target_dir = "target_pytorch"
13random_state = 2
14
15if __name__ == "__main__":
16    torch.manual_seed(random_state)
17    if torch.cuda.is_available():
18        torch.cuda.manual_seed_all(random_state)
19        logging.info(torch.cuda.get_device_name(torch.cuda.current_device()))
20    else:
21        logging.info("Found no NVIDIA driver on your system")
22
23    #############################################################################
24    # Dataset loading and model training
25    #############################################################################
26
27    logging.info("Loading dataset")
28
29    # Access dataset
30    data_handler = Cifar10()
31
32    # Get the (preprocessed) dataset
33    dataset = data_handler.get_dataset()
34
35    # Create data splits
36    indices_train, indices_test = data_handler.get_train_test_indices()
37
38    # Get dataloaders
39    train_loader = data_handler.get_dataloader(dataset, indices_train, shuffle=True)
40    test_loader = data_handler.get_dataloader(dataset, indices_test, shuffle=False)
41
42    logging.info("Defining the model")
43
44    model_params = {
45        "n_kernel": 5,
46    }
47    train_params = {
48        "epochs": 100,
49        "learning_rate": 0.001,
50        "momentum": 0.9,
51    }
52    model = Net(**model_params)
53
54    logging.info("Training the model")
55    train(model, train_loader, **train_params)
56
57    logging.info("Testing the model")
58    test(model, test_loader, data_handler.classes)
59
60    #############################################################################
61    # Below shows the use of the Target class to help generate the target_dir/
62    # If you have already saved your model, you can use the CLI target generator.
63    #############################################################################
64
65    logging.info("Wrapping the model and data in a Target object")
66    target = Target(
67        model=model,
68        model_module_path="model.py",
69        model_params=model_params,  # Must match all required in model constructor
70        train_module_path="train.py",
71        train_params=train_params,  # Must match all required in the train function
72        dataset_module_path="dataset.py",
73        dataset_name="Cifar10",  # Must match the class name in dataset module
74        indices_train=indices_train,
75        indices_test=indices_test,
76    )
77
78    logging.info("Writing Target object to directory: '%s'", target_dir)
79    target.save(target_dir)

Note: the training script above (CIFAR Dataset PyTorch Training) is included from examples/pytorch/cifar/train_pytorch.py. The dataset handler and model it imports are shown below. The `train` and `test` functions come from train.py in the same directory, which is not reproduced on this page.

Model Architecture:

CIFAR CNN Model Architecture (from examples/pytorch/cifar/model.py)
 1"""An example Pytorch classifier."""
 2
 3import torch
 4from torch import nn
 5
 6
 7class Net(nn.Module):
 8    """A Pytorch classification model for cifar10."""
 9
10    def __init__(self, n_kernel: int = 5):
11        super().__init__()
12        self.conv1 = nn.Conv2d(3, 6, n_kernel)
13        self.pool = nn.MaxPool2d(2, 2)
14        self.conv2 = nn.Conv2d(6, 16, n_kernel)
15        self.fc1 = nn.Linear(16 * n_kernel * n_kernel, 120)
16        self.fc2 = nn.Linear(120, 84)
17        self.fc3 = nn.Linear(84, 10)
18
19    def forward(self, x):
20        """Forward propagate input."""
21        x = self.pool(torch.relu(self.conv1(x)))
22        x = self.pool(torch.relu(self.conv2(x)))
23        x = torch.flatten(x, 1)
24        x = torch.relu(self.fc1(x))
25        x = torch.relu(self.fc2(x))
26        return self.fc3(x)

Note: this model architecture is included from examples/pytorch/cifar/model.py and contains the network and related definitions used by the training script.

Dataset Processing:

CIFAR Dataset Processing (from examples/pytorch/cifar/dataset.py)
 1"""Example dataset handler for CIFAR10.
 2
 3PyTorch datasets must implement `sacroml.attacks.data.PyTorchDataHandler`.
 4"""
 5
 6from collections.abc import Sequence
 7
 8from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
 9from torchvision import transforms
10from torchvision.datasets import CIFAR10
11
12from sacroml.attacks.data import PyTorchDataHandler
13
14
15class Cifar10(PyTorchDataHandler):
16    """CIFAR10 dataset handler."""
17
18    def __init__(self) -> None:
19        """Fetch and process CIFAR10."""
20        self.transform = transforms.Compose(
21            [
22                transforms.ToTensor(),
23                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
24            ]
25        )
26
27        train_set = CIFAR10(
28            root="./data", train=True, download=True, transform=self.transform
29        )
30
31        test_set = CIFAR10(
32            root="./data", train=False, download=True, transform=self.transform
33        )
34
35        self.dataset = ConcatDataset([train_set, test_set])
36
37        self.classes = (
38            "plane",
39            "car",
40            "bird",
41            "cat",
42            "deer",
43            "dog",
44            "frog",
45            "horse",
46            "ship",
47            "truck",
48        )
49
50    def __len__(self) -> int:
51        """Return the length of the dataset."""
52        return len(self.dataset)
53
54    def get_raw_dataset(self) -> Dataset | None:
55        """Return a raw unprocessed dataset."""
56        # Raw data only required for attribute inference
57        return None
58
59    def get_dataset(self) -> Dataset:
60        """Return a preprocessed dataset."""
61        return self.dataset
62
63    def get_dataloader(
64        self,
65        dataset: Dataset,
66        indices: Sequence[int],
67        batch_size: int = 32,
68        shuffle: bool = False,
69    ) -> DataLoader:
70        """Return a data loader with a requested subset of samples."""
71        subset = Subset(dataset, indices)
72        return DataLoader(subset, batch_size=batch_size, shuffle=shuffle)
73
74    def get_train_test_indices(self) -> tuple[Sequence[int], Sequence[int]]:
75        """Return train and test set indices."""
76        train = range(50000)
77        test = range(50000, 60000)
78        return train, test

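Note: dataset loading and preprocessing are handled by the `Cifar10` class in examples/pytorch/cifar/dataset.py, which implements `sacroml.attacks.data.PyTorchDataHandler`. The training script uses its `get_dataset`, `get_train_test_indices` and `get_dataloader` methods and passes the class name to `Target` via `dataset_name="Cifar10"`.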

Running Privacy Attacks:

Privacy Attacks on CIFAR Model (from examples/pytorch/cifar/attack_pytorch.py)
 1"""Example of how to run attacks on a model saved with the Target wrapper."""
 2
 3import logging
 4
 5from sacroml.attacks.likelihood_attack import LIRAAttack
 6from sacroml.attacks.target import Target
 7from sacroml.attacks.worst_case_attack import WorstCaseAttack
 8
 9output_dir = "output_pytorch"
10target_dir = "target_pytorch"
11
12
13if __name__ == "__main__":
14    logging.info("Loading Target object from '%s'", target_dir)
15
16    target = Target()
17    target.load(target_dir)
18
19    logging.info("Running attacks...")
20
21    attack = WorstCaseAttack(n_reps=10, output_dir=output_dir)
22    attack.attack(target)
23
24    attack = LIRAAttack(n_shadow_models=40, output_dir=output_dir)
25    attack.attack(target)

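Note: the attack code is taken from examples/pytorch/cifar/attack_pytorch.py. It loads the Target that the CIFAR training script saved to `target_pytorch`, so run examples/pytorch/cifar/train_pytorch.py first. The attacks may also rely on the model and dataset code shown above.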