PyTorch Examples
This section demonstrates how to use SACRO-ML with PyTorch models for privacy assessment.
Simple PyTorch Example
A basic example showing how to train a simple PyTorch model and run privacy attacks.
Training the Model:
1"""Train a classifier on synthetic data using sacroml Target and Dataset classes."""
2
3import logging
4
5import torch
6from dataset import Synthetic
7from model import OverfitNet
8from train import test, train
9
10from sacroml.attacks.target import Target
11
target_dir = "target_pytorch"
random_state = 2

if __name__ == "__main__":
    # Fix the seeds first so that repeated runs start from the same state.
    torch.manual_seed(random_state)
    if not torch.cuda.is_available():
        logging.info("Found no NVIDIA driver on your system")
    else:
        torch.cuda.manual_seed_all(random_state)
        logging.info(torch.cuda.get_device_name(torch.cuda.current_device()))

    # -------------------------------------------------------------------------
    # Step 1: load the data, then fit and evaluate the model
    # -------------------------------------------------------------------------

    logging.info("Loading dataset")

    # The handler supplies the (preprocessed) dataset, the train/test split
    # and the data loaders built from that split.
    handler = Synthetic()
    full_dataset = handler.get_dataset()
    train_idx, test_idx = handler.get_train_test_indices()
    train_loader = handler.get_dataloader(full_dataset, train_idx, shuffle=True)
    test_loader = handler.get_dataloader(full_dataset, test_idx, shuffle=False)

    logging.info("Defining the model")

    # Both dictionaries are kept as variables because they are passed to the
    # Target wrapper again further down.
    model_params = {"x_dim": 4, "y_dim": 4, "n_units": 1000}
    train_params = {"epochs": 1000, "learning_rate": 0.001, "momentum": 0.9}
    model = OverfitNet(**model_params)

    logging.info("Training the model")
    train(model, train_loader, **train_params)

    logging.info("Testing the model")
    test(model, test_loader)

    # -------------------------------------------------------------------------
    # Step 2: use the Target class to help generate the target_dir/
    # If the model has already been saved, the CLI target generator can be
    # used instead.
    # -------------------------------------------------------------------------

    logging.info("Wrapping the model and data in a Target object")
    target = Target(
        model=model,
        model_module_path="model.py",
        # Must cover every required argument of the model constructor.
        model_params=model_params,
        train_module_path="train.py",
        # Must cover every required argument of the train function.
        train_params=train_params,
        dataset_module_path="dataset.py",
        # Must be the name of the handler class in the dataset module.
        dataset_name="Synthetic",
        indices_train=train_idx,
        indices_test=test_idx,
    )

    logging.info("Writing Target object to directory: '%s'", target_dir)
    target.save(target_dir)
Note: the training script above is included from examples/pytorch/simple/train_pytorch.py. It imports Synthetic from dataset.py, OverfitNet from model.py, and the train and test functions from train.py, which are expected to sit alongside it in the same example directory. Of these, only model.py is reproduced on this page.
Model Definition:
1"""An example Pytorch classifier."""
2
3import torch
4
5
class OverfitNet(torch.nn.Module):
    """A small fully connected Pytorch classifier used as an example."""

    def __init__(self, x_dim: int, y_dim: int, n_units: int) -> None:
        """Build a network with a single hidden layer.

        Parameters
        ----------
        x_dim : int
            Number of input features.
        y_dim : int
            Number of values produced by the final linear layer.
        n_units : int
            Width of the hidden layer.
        """
        super().__init__()
        # Layers are created in input-to-output order and stored under
        # ``self.layers`` so parameter names stay ``layers.<index>.*``.
        hidden = torch.nn.Linear(x_dim, n_units)
        activation = torch.nn.ReLU()
        output = torch.nn.Linear(n_units, y_dim)
        self.layers = torch.nn.Sequential(hidden, activation, output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the layer stack for the batch ``x``."""
        outputs = self.layers(x)
        return outputs
Note: this model code is included from examples/pytorch/simple/model.py. It defines the OverfitNet class that the training script imports and registers with the Target wrapper through model_module_path="model.py".
Running Privacy Attacks:
1"""Example of how to run attacks on a model saved with the Target wrapper."""
2
3import logging
4
5from sacroml.attacks.likelihood_attack import LIRAAttack
6from sacroml.attacks.target import Target
7from sacroml.attacks.worst_case_attack import WorstCaseAttack
8
output_dir = "output_pytorch"
target_dir = "target_pytorch"


if __name__ == "__main__":
    logging.info("Loading Target object from '%s'", target_dir)

    # Start from an empty wrapper and populate it from the saved directory.
    target = Target()
    target.load(target_dir)

    logging.info("Running attacks...")

    # The attacks run one after another against the same target; each is
    # constructed only when it is about to be run.
    WorstCaseAttack(n_reps=10, output_dir=output_dir).attack(target)

    LIRAAttack(n_shadow_models=100, output_dir=output_dir).attack(target)
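Note: the attack example is included from examples/pytorch/simple/attack_pytorch.py. It loads the Target that the training script saved to the target_pytorch directory, so run the training script first. Both attacks are given output_pytorch as their output directory.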
CIFAR Dataset Example
Advanced example using CIFAR dataset with convolutional neural networks.
Training the Model:
1"""Train a PyTorch classifier on CIFAR10 using sacroml Target and Dataset classes."""
2
3import logging
4
5import torch
6from dataset import Cifar10
7from model import Net
8from train import test, train
9
10from sacroml.attacks.target import Target
11
target_dir = "target_pytorch"
random_state = 2

if __name__ == "__main__":
    # Fix the seeds first so that repeated runs start from the same state.
    torch.manual_seed(random_state)
    if not torch.cuda.is_available():
        logging.info("Found no NVIDIA driver on your system")
    else:
        torch.cuda.manual_seed_all(random_state)
        logging.info(torch.cuda.get_device_name(torch.cuda.current_device()))

    # -------------------------------------------------------------------------
    # Step 1: load the data, then fit and evaluate the model
    # -------------------------------------------------------------------------

    logging.info("Loading dataset")

    # The handler supplies the (preprocessed) dataset, the train/test split
    # and the data loaders built from that split.
    handler = Cifar10()
    full_dataset = handler.get_dataset()
    train_idx, test_idx = handler.get_train_test_indices()
    train_loader = handler.get_dataloader(full_dataset, train_idx, shuffle=True)
    test_loader = handler.get_dataloader(full_dataset, test_idx, shuffle=False)

    logging.info("Defining the model")

    # Both dictionaries are kept as variables because they are passed to the
    # Target wrapper again further down.
    model_params = {"n_kernel": 5}
    train_params = {"epochs": 100, "learning_rate": 0.001, "momentum": 0.9}
    model = Net(**model_params)

    logging.info("Training the model")
    train(model, train_loader, **train_params)

    logging.info("Testing the model")
    test(model, test_loader, handler.classes)

    # -------------------------------------------------------------------------
    # Step 2: use the Target class to help generate the target_dir/
    # If the model has already been saved, the CLI target generator can be
    # used instead.
    # -------------------------------------------------------------------------

    logging.info("Wrapping the model and data in a Target object")
    target = Target(
        model=model,
        model_module_path="model.py",
        # Must cover every required argument of the model constructor.
        model_params=model_params,
        train_module_path="train.py",
        # Must cover every required argument of the train function.
        train_params=train_params,
        dataset_module_path="dataset.py",
        # Must be the name of the handler class in the dataset module.
        dataset_name="Cifar10",
        indices_train=train_idx,
        indices_test=test_idx,
    )

    logging.info("Writing Target object to directory: '%s'", target_dir)
    target.save(target_dir)
Note: the training script above is included from examples/pytorch/cifar/train_pytorch.py. The model (model.py) and dataset handler (dataset.py) it imports are shown below; the train and test functions it imports from train.py are not reproduced on this page.
Caption: CIFAR Dataset PyTorch Training (from examples/pytorch/cifar/train_pytorch.py)
Model Architecture:
1"""An example Pytorch classifier."""
2
3import torch
4from torch import nn
5
6
class Net(nn.Module):
    """A Pytorch classification model for cifar10.

    Two convolution + max-pool stages are followed by three fully connected
    layers that produce 10 outputs. The first fully connected layer is sized
    for 3-channel 32x32 inputs.
    """

    def __init__(self, n_kernel: int = 5) -> None:
        """Construct the network.

        Parameters
        ----------
        n_kernel : int
            Side length of the square kernel used by both convolutions.

        Raises
        ------
        ValueError
            If the kernel is so large that nothing is left of a 32x32 input
            after the two convolution and pooling stages.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, n_kernel)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, n_kernel)

        # Size fc1 from the feature map that actually reaches it. Each stage
        # is an unpadded stride-1 convolution (size - n_kernel + 1) followed
        # by a 2x2 max-pool (floor division by 2). Using n_kernel itself as
        # the side length is only right for n_kernel == 5, where the map
        # happens to be 5x5.
        side = 32
        for _ in range(2):
            side = (side - n_kernel + 1) // 2
        if side < 1:
            raise ValueError(
                f"n_kernel={n_kernel} is too large for 32x32 inputs: no "
                "features remain after the convolution and pooling stages"
            )

        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagate input.

        Parameters
        ----------
        x : torch.Tensor
            Batch of inputs; assumed to be shaped (batch, 3, 32, 32) to match
            the size ``fc1`` was built for.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer, 10 values per sample.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Keep the batch dimension and flatten everything else for fc1.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
Note: this model architecture is included from examples/pytorch/cifar/model.py and contains the network and related definitions used by the training script.
Dataset Processing:
1"""Example dataset handler for CIFAR10.
2
3PyTorch datasets must implement `sacroml.attacks.data.PyTorchDataHandler`.
4"""
5
6from collections.abc import Sequence
7
8from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
9from torchvision import transforms
10from torchvision.datasets import CIFAR10
11
12from sacroml.attacks.data import PyTorchDataHandler
13
14
class Cifar10(PyTorchDataHandler):
    """Data handler that exposes CIFAR10 through the sacroml interface."""

    def __init__(self) -> None:
        """Download CIFAR10 and join its two splits into a single dataset."""
        to_tensor = transforms.ToTensor()
        normalise = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.transform = transforms.Compose([to_tensor, normalise])

        # Training split first, then the test split: get_train_test_indices
        # relies on this ordering.
        train_set, test_set = (
            CIFAR10(
                root="./data",
                train=is_train,
                download=True,
                transform=self.transform,
            )
            for is_train in (True, False)
        )
        self.dataset = ConcatDataset([train_set, test_set])

        self.classes = (
            "plane",
            "car",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )

    def __len__(self) -> int:
        """Return the number of samples in the combined dataset."""
        return len(self.dataset)

    def get_raw_dataset(self) -> Dataset | None:
        """Return the unprocessed dataset, which this handler does not keep."""
        # Raw data is only needed for attribute inference.
        return None

    def get_dataset(self) -> Dataset:
        """Return the preprocessed (transformed and concatenated) dataset."""
        return self.dataset

    def get_dataloader(
        self,
        dataset: Dataset,
        indices: Sequence[int],
        batch_size: int = 32,
        shuffle: bool = False,
    ) -> DataLoader:
        """Return a data loader restricted to the samples at ``indices``."""
        return DataLoader(
            Subset(dataset, indices), batch_size=batch_size, shuffle=shuffle
        )

    def get_train_test_indices(self) -> tuple[Sequence[int], Sequence[int]]:
        """Return the index ranges of the train and test samples.

        The ranges follow the order in which the splits were concatenated
        (assumes 50000 training and 10000 test samples).
        """
        return range(50000), range(50000, 60000)
Note: dataset loading and preprocessing functions are provided in examples/pytorch/cifar/dataset.py; training and evaluation snippets reference these utilities.
Running Privacy Attacks:
1"""Example of how to run attacks on a model saved with the Target wrapper."""
2
3import logging
4
5from sacroml.attacks.likelihood_attack import LIRAAttack
6from sacroml.attacks.target import Target
7from sacroml.attacks.worst_case_attack import WorstCaseAttack
8
output_dir = "output_pytorch"
target_dir = "target_pytorch"


if __name__ == "__main__":
    logging.info("Loading Target object from '%s'", target_dir)

    # Start from an empty wrapper and populate it from the saved directory.
    target = Target()
    target.load(target_dir)

    logging.info("Running attacks...")

    # The attacks run one after another against the same target; each is
    # constructed only when it is about to be run.
    WorstCaseAttack(n_reps=10, output_dir=output_dir).attack(target)

    LIRAAttack(n_shadow_models=40, output_dir=output_dir).attack(target)
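Note: the attack code is taken from examples/pytorch/cifar/attack_pytorch.py. It loads the Target that the CIFAR training script saved to the target_pytorch directory, so run that training script first. This example uses the same target_pytorch and output_pytorch directory names as the simple example above.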