diff --git a/examples/attack/targeted/rbcd_attack.py b/examples/attack/targeted/rbcd_attack.py new file mode 100644 index 0000000..3b443ea --- /dev/null +++ b/examples/attack/targeted/rbcd_attack.py @@ -0,0 +1,93 @@ +import os.path as osp + +import torch +import torch_geometric.transforms as T + +from greatx.attack.targeted import GRBCDAttack, PRBCDAttack +from greatx.datasets import GraphDataset +from greatx.nn.models import GCN +from greatx.training import Trainer +from greatx.training.callbacks import ModelCheckpoint +from greatx.utils import mark, split_nodes + +dataset = 'Cora' +root = osp.join(osp.dirname(osp.realpath(__file__)), '../../..', 'data') +dataset = GraphDataset(root=root, name=dataset, + transform=T.LargestConnectedComponents()) + +data = dataset[0] +splits = split_nodes(data.y, random_state=15) + +num_features = data.x.size(-1) +num_classes = data.y.max().item() + 1 + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# ================================================================== # +# Attack Setting # +# ================================================================== # +target = 1 # target node to attack +target_label = data.y[target].item() + +# ================================================================== # +# Before Attack # +# ================================================================== # +trainer_before = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_before.pth', monitor='val_acc') +trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +output = trainer_before.predict(data, mask=target) +print("Before attack:") +print(mark(output, target_label)) + +# ================================================================== # +# Attacking (PRBCDAttack) # +# ================================================================== # +attacker = PRBCDAttack(data, device=device) +attacker.setup_surrogate(trainer_before.model) +attacker.reset() 
+attacker.attack(target) + +# ================================================================== # +# After evasion Attack # +# ================================================================== # +output = trainer_before.predict(attacker.data(), mask=target) +print("After evasion attack:") +print(mark(output, target_label)) + +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +output = trainer_after.predict(attacker.data(), mask=target) +print("After poisoning attack:") +print(mark(output, target_label)) + +# ================================================================== # +# Attacking (GRBCDAttack) # +# ================================================================== # +attacker = GRBCDAttack(data, device=device) +attacker.setup_surrogate(trainer_before.model) +attacker.reset() +attacker.attack(target) + +# ================================================================== # +# After evasion Attack # +# ================================================================== # +output = trainer_before.predict(attacker.data(), mask=target) +print("After evasion attack:") +print(mark(output, target_label)) + +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +output = trainer_after.predict(attacker.data(), mask=target) +print("After poisoning attack:") 
+print(mark(output, target_label)) diff --git a/examples/attack/untargeted/rbcd_attack.py b/examples/attack/untargeted/rbcd_attack.py new file mode 100644 index 0000000..f6388e5 --- /dev/null +++ b/examples/attack/untargeted/rbcd_attack.py @@ -0,0 +1,90 @@ +import os.path as osp + +import torch +import torch_geometric.transforms as T + +from greatx.attack.untargeted import GRBCDAttack, PRBCDAttack +from greatx.datasets import GraphDataset +from greatx.nn.models import GCN +from greatx.training import Trainer +from greatx.training.callbacks import ModelCheckpoint +from greatx.utils import split_nodes + +dataset = 'Cora' +root = osp.join(osp.dirname(osp.realpath(__file__)), '../../..', 'data') +dataset = GraphDataset(root=root, name=dataset, + transform=T.LargestConnectedComponents()) + +data = dataset[0] +splits = split_nodes(data.y, random_state=15) + +num_features = data.x.size(-1) +num_classes = data.y.max().item() + 1 + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# ================================================================== # +# Before Attack # +# ================================================================== # +trainer_before = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_before.pth', monitor='val_acc') +trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +logs = trainer_before.evaluate(data, splits.test_nodes) +print(f"Before attack\n {logs}") + +# ================================================================== # +# Attacking (PRBCDAttack) # +# ================================================================== # +attacker = PRBCDAttack(data, device=device) +attacker.setup_surrogate( + trainer_before.model, + victim_nodes=splits.test_nodes, + # set True to use ground-truth labels + ground_truth=False, +) +attacker.reset() +attacker.attack(0.05) + +# ================================================================== # +# After evasion Attack # 
+# ================================================================== # +logs = trainer_before.evaluate(attacker.data(), splits.test_nodes) +print(f"After evasion attack\n {logs}") +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +logs = trainer_after.evaluate(attacker.data(), splits.test_nodes) +print(f"After poisoning attack\n {logs}") + +# ================================================================== # +# Attacking (GRBCDAttack) # +# ================================================================== # +attacker = GRBCDAttack(data, device=device) +attacker.setup_surrogate( + trainer_before.model, + victim_nodes=splits.test_nodes, + # set True to use ground-truth labels + ground_truth=False, +) +attacker.reset() +attacker.attack(0.05) + +# ================================================================== # +# After evasion Attack # +# ================================================================== # +logs = trainer_before.evaluate(attacker.data(), splits.test_nodes) +print(f"After evasion attack\n {logs}") +# ================================================================== # +# After poisoning Attack # +# ================================================================== # +trainer_after = Trainer(GCN(num_features, num_classes), device=device) +ckp = ModelCheckpoint('model_after.pth', monitor='val_acc') +trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes), + callbacks=[ckp]) +logs = trainer_after.evaluate(attacker.data(), splits.test_nodes) +print(f"After poisoning attack\n {logs}") diff --git a/greatx/attack/targeted/__init__.py b/greatx/attack/targeted/__init__.py index 
4204717..31c6a78 100644 --- a/greatx/attack/targeted/__init__.py +++ b/greatx/attack/targeted/__init__.py @@ -1,3 +1,4 @@ +from .targeted_attacker import TargetedAttacker from .dice_attack import DICEAttack from .fg_attack import FGAttack from .gf_attack import GFAttack @@ -6,7 +7,7 @@ from .pgd_attack import PGDAttack from .random_attack import RandomAttack from .sg_attack import SGAttack -from .targeted_attacker import TargetedAttacker +from .rbcd_attack import PRBCDAttack, GRBCDAttack classes = __all__ = [ 'TargetedAttacker', @@ -18,4 +19,6 @@ 'Nettack', 'GFAttack', 'PGDAttack', + 'PRBCDAttack', + 'GRBCDAttack', ] diff --git a/greatx/attack/targeted/rbcd_attack.py b/greatx/attack/targeted/rbcd_attack.py new file mode 100644 index 0000000..224aa41 --- /dev/null +++ b/greatx/attack/targeted/rbcd_attack.py @@ -0,0 +1,296 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + +import torch +from torch import Tensor +from torch_geometric.utils import coalesce, to_undirected + +from greatx.attack.targeted.targeted_attacker import TargetedAttacker +from greatx.attack.untargeted.rbcd_attack import RBCDAttack +from greatx.attack.untargeted.utils import project +from greatx.nn.models.surrogate import Surrogate + +# (predictions, labels, ids/mask) -> Tensor with one element +METRIC = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] + + +class PRBCDAttack(TargetedAttacker, RBCDAttack, Surrogate): + r"""Projected Randomized Block Coordinate Descent (PRBCD) adversarial + attack from the `Robustness of Graph Neural Networks at Scale + `_ paper. + + This attack uses an efficient gradient based approach that (during the + attack) relaxes the discrete entries in the adjacency matrix + :math:`\{0, 1\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix + (no feature perturbations). Thus, this attack supports all models that can + handle weighted graphs that are differentiable w.r.t. these edge weights. 
+ For non-differentiable models you might be able to e.g. use the gumble + softmax trick. + + The memory overhead is driven by the additional edges (at most + :attr:`block_size`). For scalability reasons, the block is drawn with + replacement and then the index is made unique. Thus, the actual block size + is typically slightly smaller than specified. + + This attack can be used for both global and local attacks as well as + test-time attacks (evasion) and training-time attacks (poisoning). Please + see the provided examples. + + This attack is designed with a focus on node- or graph-classification, + however, to adapt to other tasks you most likely only need to provide an + appropriate loss and model. However, we currently do not support batching + out of the box (sampling needs to be adapted). + + """ + def reset(self) -> "PRBCDAttack": + super().reset() + self.current_block = None + self.block_edge_index = None + self.block_edge_weight = None + self.loss = None + self.metric = None + + self.victim_nodes = None + self.victim_labels = None + + # NOTE: Since `edge_index` and `edge_weight` denote the original graph + # here we need to name them as `_edge_index`and `_edge_weight` + self._edge_index = self.edge_index + self._edge_weight = torch.ones(self.num_edges, device=self.device) + + # For early stopping (not explicitly covered by pseudo code) + self.best_metric = float('-Inf') + + # For collecting attack statistics + self.attack_statistics = defaultdict(list) + + return self + + def attack( + self, + target, + *, + target_label=None, + num_budgets=None, + direct_attack=True, + block_size: int = 250_000, + epochs: int = 125, + epochs_resampling: int = 100, + loss: Optional[str] = 'tanh_margin', + metric: Optional[Union[str, METRIC]] = None, + lr: float = 2_000, + structure_attack: bool = True, + feature_attack: bool = False, + disable: bool = False, + **kwargs, + ) -> "PRBCDAttack": + + super().attack(target, target_label, num_budgets=num_budgets, + 
direct_attack=direct_attack, + structure_attack=structure_attack, + feature_attack=feature_attack) + self.victim_nodes = torch.as_tensor( + target, + dtype=torch.long, + device=self.device, + ).view(-1) + + self.victim_labels = self.target_label.view(-1) + + return RBCDAttack.attack(self, block_size=block_size, epochs=epochs, + epochs_resampling=epochs_resampling, + loss=loss, metric=metric, lr=lr, + disable=disable, **kwargs) + + def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]: + """Prepare attack and return the iterable sequence steps.""" + + # Sample initial search space (Algorithm 1, line 3-4) + self.sample_random_block(num_budgets) + + return range(epochs) + + @torch.no_grad() + def update(self, epoch: int, gradient: Tensor) -> Dict[str, float]: + """Update edge weights given gradient.""" + # Gradient update step (Algorithm 1, line 7) + self.update_edge_weights(epoch, gradient) + + # For monitoring + pmass_update = torch.clamp(self.block_edge_weight, 0, 1) + # Projection to stay within relaxed `L_0` num_budgets + # (Algorithm 1, line 8) + self.block_edge_weight = project(self.num_budgets, + self.block_edge_weight, + self.coeffs['eps']) + + # For monitoring + scalars = dict( + prob_mass_after_update=pmass_update.sum().item(), + prob_mass_after_update_max=pmass_update.max().item(), + prob_mass_afterprojection=self.block_edge_weight.sum().item(), + prob_mass_afterprojection_nonzero_weights=( + self.block_edge_weight > self.coeffs['eps']).sum().item(), + prob_mass_afterprojection_max=self.block_edge_weight.max().item(), + ) + + if not self.coeffs['with_early_stopping']: + return scalars + + # Calculate metric after the current epoch (overhead + # for monitoring and early stopping) + + topk_block_edge_weight = torch.zeros_like(self.block_edge_weight) + topk_block_edge_weight[torch.topk(self.block_edge_weight, + self.num_budgets).indices] = 1 + + edge_index, edge_weight = self.get_modified_graph( + self._edge_index, self._edge_weight, 
self.block_edge_index, + topk_block_edge_weight) + + prediction = self.surrogate(self.feat, edge_index, + edge_weight)[self.victim_nodes] + metric = self.metric(prediction, self.victim_labels) + + # Save best epoch for early stopping + # (not explicitly covered by pseudo code) + if metric > self.best_metric: + self.best_metric = metric + self.best_block = self.current_block.cpu().clone() + self.best_edge_index = self.block_edge_index.cpu().clone() + self.best_pert_edge_weight = self.block_edge_weight.cpu().detach() + + # Resampling of search space (Algorithm 1, line 9-14) + if epoch < self.epochs_resampling - 1: + self.resample_random_block(self.num_budgets) + elif epoch == self.epochs_resampling - 1: + # Retrieve best epoch if early stopping is active + # (not explicitly covered by pseudo code) + self.current_block = self.best_block.to(self.device) + self.block_edge_index = self.best_edge_index.to(self.device) + block_edge_weight = self.best_pert_edge_weight.clone() + self.block_edge_weight = block_edge_weight.to(self.device) + + scalars['metric'] = metric.item() + return scalars + + def get_flipped_edges(self) -> Tensor: + """Clean up and prepare return flipped edges.""" + + # Retrieve best epoch if early stopping is active + # (not explicitly covered by pseudo code) + if self.coeffs['with_early_stopping']: + self.current_block = self.best_block.to(self.device) + self.block_edge_index = self.best_edge_index.to(self.device) + self.block_edge_weight = self.best_pert_edge_weight.to(self.device) + + # Sample final discrete graph (Algorithm 1, line 16) + return self.sample_final_edges() + + +class GRBCDAttack(PRBCDAttack): + r"""Greedy Randomized Block Coordinate Descent (GRBCD) adversarial attack + from the `Robustness of Graph Neural Networks at Scale + `_ paper. + + GRBCD shares most of the properties and requirements with + :class:`PRBCDAttack`. It also uses an efficient gradient based approach. 
+    However, it greedily flips edges based on the gradient towards the +    adjacency matrix. +    """ +    def attack( +        self, +        target, +        *, +        target_label=None, +        num_budgets=None, +        direct_attack=True, +        block_size: int = 250_000, +        epochs: int = 125, +        epochs_resampling: int = 100, +        loss: Optional[str] = 'mce', +        metric: Optional[Union[str, METRIC]] = None, +        lr: float = 1_000, +        structure_attack: bool = True, +        feature_attack: bool = False, +        disable: bool = False, +        **kwargs, +    ) -> "GRBCDAttack": + +        return super().attack(target=target, target_label=target_label, +                              direct_attack=direct_attack, +                              num_budgets=num_budgets, block_size=block_size, +                              epochs=epochs, +                              epochs_resampling=epochs_resampling, +                              metric=metric, loss=loss, lr=lr, disable=disable, +                              structure_attack=structure_attack, +                              feature_attack=feature_attack, **kwargs) + +    def prepare(self, num_budgets: int, epochs: int) -> List[int]: +        """Prepare attack.""" + +        # Determine the number of edges to be flipped in each attack step/epoch +        step_size = num_budgets // epochs +        if step_size > 0: +            steps = epochs * [step_size] +            for i in range(num_budgets % epochs): +                steps[i] += 1 +        else: +            steps = [1] * num_budgets + +        # Sample initial search space (Algorithm 2, line 3-4) +        self.sample_random_block(step_size) + +        return steps + +    def reset(self) -> "GRBCDAttack": +        super().reset() +        self.flipped_edges = self._edge_index.new_empty(2, 0) +        return self + +    @torch.no_grad() +    def update( +        self, +        step_size: int, +        gradient: Tensor, +    ) -> Dict[str, Any]: +        """Update edge weights given gradient.""" +        _, topk_edge_index = torch.topk(gradient, step_size) + +        flip_edge_index = self.block_edge_index[:, topk_edge_index].to( +            self.device) +        flip_edge_weight = torch.ones(flip_edge_index.size(1), +                                      device=self.device) + +        self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), +                                       axis=-1) + +        if self.is_undirected: +            flip_edge_index, flip_edge_weight = to_undirected( +                flip_edge_index, flip_edge_weight, 
num_nodes=self.num_nodes, + reduce='mean') + + edge_index = torch.cat((self._edge_index, flip_edge_index), dim=-1) + edge_weight = torch.cat((self._edge_weight, flip_edge_weight)) + + edge_index, edge_weight = coalesce(edge_index, edge_weight, + num_nodes=self.num_nodes, + reduce='sum') + + mask = torch.isclose(edge_weight, torch.tensor(1.)) + + self._edge_index = edge_index[:, mask] + self._edge_weight = edge_weight[mask] + + # Sample initial search space (Algorithm 2, line 3-4) + self.sample_random_block(step_size) + + # Return debug information + scalars = { + 'number_positive_entries_in_gradient': (gradient > 0).sum().item() + } + return scalars + + def get_flipped_edges(self) -> Tensor: + """Clean up and prepare return flipped edges.""" + return self.flipped_edges diff --git a/greatx/attack/untargeted/__init__.py b/greatx/attack/untargeted/__init__.py index a7eefa2..c74fb2e 100644 --- a/greatx/attack/untargeted/__init__.py +++ b/greatx/attack/untargeted/__init__.py @@ -5,8 +5,16 @@ from .pgd_attack import PGDAttack from .random_attack import RandomAttack from .untargeted_attacker import UntargetedAttacker +from .rbcd_attack import PRBCDAttack, GRBCDAttack classes = __all__ = [ - 'UntargetedAttacker', 'RandomAttack', 'DICEAttack', 'FGAttack', 'IGAttack', - 'Metattack', 'PGDAttack' + 'UntargetedAttacker', + 'RandomAttack', + 'DICEAttack', + 'FGAttack', + 'IGAttack', + 'Metattack', + 'PGDAttack', + 'PRBCDAttack', + 'GRBCDAttack', ] diff --git a/greatx/attack/untargeted/rbcd_attack.py b/greatx/attack/untargeted/rbcd_attack.py new file mode 100644 index 0000000..1a82ab3 --- /dev/null +++ b/greatx/attack/untargeted/rbcd_attack.py @@ -0,0 +1,602 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch_geometric.utils import coalesce, to_undirected +from tqdm.auto import tqdm + +from greatx.attack.untargeted.untargeted_attacker 
import UntargetedAttacker +from greatx.attack.untargeted.utils import ( +    linear_to_full_idx, +    linear_to_triu_idx, +    num_possible_edges, +    project, +) +from greatx.functional import ( +    masked_cross_entropy, +    probability_margin_loss, +    tanh_margin_loss, +) +from greatx.nn.models.surrogate import Surrogate + +# (predictions, labels, ids/mask) -> Tensor with one element +METRIC = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] + + +class RBCDAttack: +    """Base class for :class:`PRBCDAttack` and +    :class:`GRBCDAttack`.""" + +    # RBCDAttack will not ensure there are no singleton nodes +    _allow_singleton: bool = False + +    # TODO: Although RBCDAttack accepts directed graphs, +    # we currently don't explicitly support directed graphs. +    # This should be made available in the future. +    is_undirected: bool = True + +    coeffs: Dict[str, Any] = { +        'max_final_samples': 20, +        'max_trials_sampling': 20, +        'with_early_stopping': True, +        'eps': 1e-7 +    } + +    def attack( +        self, +        block_size: int = 250_000, +        epochs: int = 125, +        epochs_resampling: int = 100, +        loss: Optional[str] = 'tanh_margin', +        metric: Optional[Union[str, METRIC]] = None, +        lr: float = 2_000, +        disable: bool = False, +        **kwargs, +    ) -> "RBCDAttack": + +        self.block_size = block_size + +        assert loss in ['mce', 'prob_margin', 'tanh_margin'] +        if loss == 'mce': +            self.loss = masked_cross_entropy +        elif loss == 'prob_margin': +            self.loss = probability_margin_loss +        else: +            self.loss = tanh_margin_loss + +        if metric is None: +            self.metric = self.loss +        else: +            self.metric = metric + +        self.epochs_resampling = epochs_resampling +        self.lr = lr + +        self.coeffs.update(**kwargs) + +        # Loop over the epochs (Algorithm 1, line 5) +        for step in tqdm(self.prepare(self.num_budgets, epochs), +                         desc='Peturbing graph...', disable=disable): + +            loss, gradient = self.compute_gradients() + +            scalars = self.update(step, gradient) + +            scalars['loss'] = loss.item() +            self._append_statistics(scalars) + +        flipped_edges = 
self.get_flipped_edges() + + assert flipped_edges.size(1) <= self.num_budgets, ( + f'# perturbed edges {flipped_edges.size(1)} ' + f'exceeds num_budgets {self.num_budgets}') + + for it, (u, v) in enumerate(zip(*flipped_edges.tolist())): + if self.adjacency_matrix[u, v] > 0: + self.remove_edge(u, v, it) + else: + self.add_edge(u, v, it) + + return self + + def compute_gradients(self) -> Tuple[Tensor, Tensor]: + """Forward and update edge weights.""" + self.block_edge_weight.requires_grad_() + + # Retrieve sparse perturbed adjacency matrix `A \oplus p_{t-1}` + # (Algorithm 1, line 6 / Algorithm 2, line 7) + edge_index, edge_weight = self.get_modified_graph( + self._edge_index, self._edge_weight, self.block_edge_index, + self.block_edge_weight) + + # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7) + prediction = self.surrogate(self.feat, edge_index, + edge_weight)[self.victim_nodes] + + # temperature scaling, work for cross-entropy loss + if self.tau != 1: + prediction /= self.tau + + # Calculate loss combining all each node + # (Algorithm 1, line 7 / Algorithm 2, line 8) + loss = self.loss(prediction, self.victim_labels) + # Retrieve gradient towards the current block + # (Algorithm 1, line 7 / Algorithm 2, line 8) + gradient = torch.autograd.grad(loss, self.block_edge_weight)[0] + + return loss, gradient + + def get_modified_graph( + self, + edge_index: Tensor, + edge_weight: Tensor, + block_edge_index: Tensor, + block_edge_weight: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Merges adjacency matrix with current block (incl. 
weights)""" + if self.is_undirected: + block_edge_index, block_edge_weight = to_undirected( + block_edge_index, block_edge_weight, num_nodes=self.num_nodes, + reduce='mean') + + modified_edge_index = torch.cat((edge_index, block_edge_index), dim=-1) + modified_edge_weight = torch.cat((edge_weight, block_edge_weight)) + + modified_edge_index, modified_edge_weight = coalesce( + modified_edge_index, modified_edge_weight, + num_nodes=self.num_nodes, reduce='sum') + + # Allow (soft) removal of edges + mask = modified_edge_weight > 1 + modified_edge_weight[mask] = 2 - modified_edge_weight[mask] + + return modified_edge_index, modified_edge_weight + + @torch.no_grad() + def sample_random_block(self, num_budgets: int = 0): + for _ in range(self.coeffs['max_trials_sampling']): + num_possible = num_possible_edges(self.num_nodes, + self.is_undirected) + self.current_block = torch.randint(num_possible, + (self.block_size, ), + device=self.device) + self.current_block = torch.unique(self.current_block, sorted=True) + + if self.is_undirected: + self.block_edge_index = linear_to_triu_idx( + self.num_nodes, self.current_block) + else: + self.block_edge_index = linear_to_full_idx( + self.num_nodes, self.current_block) + + self._filter_self_loops_in_block(with_weight=False) + + self.block_edge_weight = torch.full(self.current_block.shape, + self.coeffs['eps'], + device=self.device) + if self.current_block.size(0) >= num_budgets: + return + + raise RuntimeError("Sampling random block was not successful. " + "Please decrease `num_budgets`.") + + def resample_random_block(self, num_budgets: int): + # Keep at most half of the block (i.e. 
resample low weights) + sorted_idx = torch.argsort(self.block_edge_weight) + keep_above = (self.block_edge_weight <= + self.coeffs['eps']).sum().long() + if keep_above < sorted_idx.size(0) // 2: + keep_above = sorted_idx.size(0) // 2 + sorted_idx = sorted_idx[keep_above:] + + self.current_block = self.current_block[sorted_idx] + + # Sample until enough edges were drawn + for _ in range(self.coeffs['max_trials_sampling']): + n_edges_resample = self.block_size - self.current_block.size(0) + num_possible = num_possible_edges(self.num_nodes, + self.is_undirected) + lin_index = torch.randint(num_possible, (n_edges_resample, ), + device=self.device) + + current_block = torch.cat((self.current_block, lin_index)) + self.current_block, unique_idx = torch.unique( + current_block, sorted=True, return_inverse=True) + + if self.is_undirected: + self.block_edge_index = linear_to_triu_idx( + self.num_nodes, self.current_block) + else: + self.block_edge_index = linear_to_full_idx( + self.num_nodes, self.current_block) + + # Merge existing weights with new edge weights + block_edge_weight_prev = self.block_edge_weight[sorted_idx] + self.block_edge_weight = torch.full(self.current_block.shape, + self.coeffs['eps'], + device=self.device) + + self.block_edge_weight[ + unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev + + if not self.is_undirected: + self._filter_self_loops_in_block(with_weight=True) + + if self.current_block.size(0) > num_budgets: + return + + raise RuntimeError("Sampling random block was not successful." 
+ "Please decrease `num_budgets`.") + + @torch.no_grad() + def sample_final_edges(self) -> Tuple[Tensor, Tensor]: + best_metric = float('-Inf') + block_edge_weight = self.block_edge_weight + block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0 + num_budgets = self.num_budgets + feat = self.feat + victim_nodes = self.victim_nodes + victim_labels = self.victim_labels + + for i in range(self.coeffs['max_final_samples']): + if i == 0: + # In first iteration employ top k heuristic instead of sampling + sampled_edges = torch.zeros_like(block_edge_weight) + sampled_edges[torch.topk(block_edge_weight, + num_budgets).indices] = 1 + else: + sampled_edges = torch.bernoulli(block_edge_weight).float() + + if sampled_edges.sum() > num_budgets: + # Allowed num_budgets is exceeded + continue + + self.block_edge_weight = sampled_edges + + edge_index, edge_weight = self.get_modified_graph( + self._edge_index, self._edge_weight, self.block_edge_index, + self.block_edge_weight) + prediction = self.surrogate(feat, edge_index, + edge_weight)[victim_nodes] + metric = self.metric(prediction, victim_labels) + + # Save best sample + if metric > best_metric: + best_metric = metric + best_edge_weight = self.block_edge_weight.clone().cpu() + + flipped_edges = self.block_edge_index[:, best_edge_weight != 0] + return flipped_edges + + def update_edge_weights(self, epoch: int, gradient: Tensor): + # The learning rate is refined heuristically, s.t. (1) it is + # independent of the number of perturbations (assuming an undirected + # adjacency matrix) and (2) to decay learning rate during fine-tuning + # (i.e. fixed search space). 
+ lr = (self.num_budgets / self.num_nodes * self.lr / + np.sqrt(max(0, epoch - self.epochs_resampling) + 1)) + self.block_edge_weight.data.add_(lr * gradient) + + def _filter_self_loops_in_block(self, with_weight: bool): + mask = self.block_edge_index[0] != self.block_edge_index[1] + self.current_block = self.current_block[mask] + self.block_edge_index = self.block_edge_index[:, mask] + if with_weight: + self.block_edge_weight = self.block_edge_weight[mask] + + def _append_statistics(self, mapping: Dict[str, Any]): + for key, value in mapping.items(): + self.attack_statistics[key].append(value) + + +class PRBCDAttack(UntargetedAttacker, RBCDAttack, Surrogate): + r"""Projected Randomized Block Coordinate Descent (PRBCD) adversarial + attack from the `Robustness of Graph Neural Networks at Scale + `_ paper. + + This attack uses an efficient gradient based approach that (during the + attack) relaxes the discrete entries in the adjacency matrix + :math:`\{0, 1\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix + (no feature perturbations). Thus, this attack supports all models that can + handle weighted graphs that are differentiable w.r.t. these edge weights. + For non-differentiable models you might be able to e.g. use the gumble + softmax trick. + + The memory overhead is driven by the additional edges (at most + :attr:`block_size`). For scalability reasons, the block is drawn with + replacement and then the index is made unique. Thus, the actual block size + is typically slightly smaller than specified. + + This attack can be used for both global and local attacks as well as + test-time attacks (evasion) and training-time attacks (poisoning). Please + see the provided examples. + + This attack is designed with a focus on node- or graph-classification, + however, to adapt to other tasks you most likely only need to provide an + appropriate loss and model. However, we currently do not support batching + out of the box (sampling needs to be adapted). 
+ +    """ +    def setup_surrogate( +        self, +        surrogate: torch.nn.Module, +        victim_nodes: Tensor, +        ground_truth: bool = False, +        *, +        tau: float = 1.0, +        freeze: bool = True, +    ) -> "PRBCDAttack": +        r"""Setup the surrogate model for adversarial attack. + +        Parameters +        ---------- +        surrogate : torch.nn.Module +            the surrogate model +        victim_nodes : Tensor +            the victim node set +        ground_truth : bool, optional +            whether to use ground-truth label for victim nodes, +            if False, the node labels are estimated by the surrogate model, +            by default False +        tau : float, optional +            the temperature of softmax activation, by default 1.0 +        freeze : bool, optional +            whether to freeze the surrogate model to avoid the +            gradient accumulation, by default True + +        Returns +        ------- +        PRBCDAttack +            the attacker itself +        """ + +        Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau, +                                  freeze=freeze) + +        if victim_nodes.dtype == torch.bool: +            victim_nodes = victim_nodes.nonzero().view(-1) +        self.victim_nodes = victim_nodes.to(self.device) + +        if ground_truth: +            self.victim_labels = self.label[victim_nodes] +        else: +            self.victim_labels = self.estimate_self_training_labels( +                victim_nodes) + +        return self + +    def reset(self) -> "PRBCDAttack": +        super().reset() +        self.current_block = None +        self.block_edge_index = None +        self.block_edge_weight = None +        self.loss = None +        self.metric = None + +        # NOTE: Since `edge_index` and `edge_weight` denote the original graph +        # here we need to name them as `_edge_index` and `_edge_weight` +        self._edge_index = self.edge_index +        self._edge_weight = torch.ones(self.num_edges, device=self.device) + +        # For early stopping (not explicitly covered by pseudo code) +        self.best_metric = float('-Inf') + +        # For collecting attack statistics +        self.attack_statistics = defaultdict(list) + +        return self + +    def attack( +        self, +        num_budgets: Union[int, float] = 0.05, +        *, +        block_size: int = 250_000, +        epochs: int = 125, +        epochs_resampling: int = 100, +        loss: 
def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]:
    """Prepare the attack and return the iterable of optimization steps.

    Draws the initial random block of candidate edges
    (Algorithm 1, line 3-4) and yields one step index per epoch.

    Parameters
    ----------
    num_budgets : int
        total number of edge flips allowed
    epochs : int
        number of optimization epochs

    Returns
    -------
    Iterable[int]
        the epoch indices ``0 .. epochs - 1``
    """
    # Sample initial search space (Algorithm 1, line 3-4).
    self.sample_random_block(num_budgets)
    return range(epochs)
def get_flipped_edges(self) -> Tensor:
    """Clean up after the attack and return the flipped edges.

    If early stopping is enabled, the best block found during
    optimization is restored first (not explicitly covered by the
    pseudo code); then the final discrete perturbation is sampled
    (Algorithm 1, line 16).
    """
    if self.coeffs['with_early_stopping']:
        # Restore the best snapshot (kept on CPU) back onto the device.
        best_state = {
            'current_block': self.best_block,
            'block_edge_index': self.best_edge_index,
            'block_edge_weight': self.best_pert_edge_weight,
        }
        for attr, saved in best_state.items():
            setattr(self, attr, saved.to(self.device))

    # Sample final discrete graph (Algorithm 1, line 16).
    return self.sample_final_edges()
def prepare(self, num_budgets: int, epochs: int) -> List[int]:
    """Prepare the attack and return the per-epoch flip counts.

    Distributes `num_budgets` edge flips as evenly as possible over
    `epochs` attack steps; if the budget is smaller than the number of
    epochs, one edge is flipped per step until the budget is spent.
    Also draws the initial random block (Algorithm 2, line 3-4).
    """
    # Determine the number of edges to be flipped in each attack
    # step/epoch.
    base, remainder = divmod(num_budgets, epochs)
    if base > 0:
        # The first `remainder` steps absorb the leftover flips.
        steps = [base + 1 if i < remainder else base
                 for i in range(epochs)]
    else:
        steps = [1] * num_budgets

    # Sample initial search space (Algorithm 2, line 3-4).
    self.sample_random_block(base)

    return steps
def bisection(edge_weight: Tensor, a: float, b: float, n_pert: int, eps=1e-5,
              max_iter=1e3) -> Tensor:
    """Bisection search used by the budget projection.

    Finds the shift `miu` such that the clamped, shifted edge weights
    sum (approximately) to the budget `n_pert`, i.e. a root of
    ``sum(clamp(edge_weight - miu, 0, 1)) - n_pert``.

    Parameters
    ----------
    edge_weight : Tensor
        relaxed edge weights to be projected
    a, b : float
        bracketing interval containing the root
    n_pert : int
        the perturbation budget
    eps : float, optional
        interval width below which the search stops, by default 1e-5
    max_iter : float, optional
        maximum number of bisection iterations, by default 1e3
    """
    def shift(offset):
        # Signed residual of the budget constraint at this offset.
        return torch.clamp(edge_weight - offset, 0, 1).sum() - n_pert

    miu = a
    for _ in range(int(max_iter)):
        miu = (a + b) / 2
        residual = shift(miu)
        # Exact root found.
        if residual == 0.0:
            break
        # Keep whichever half-interval still brackets a sign change.
        if residual * shift(a) < 0:
            b = miu
        else:
            a = miu
        # Converged to within the requested tolerance.
        if b - a <= eps:
            break
    return miu
def linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor:
    """Linear index to upper triangular matrix without diagonal.

    Maps a flat index over the strict upper triangle of an `n x n`
    matrix back to `(row, col)` coordinates in closed form (no loops),
    so that index 0 -> (0, 1), 1 -> (0, 2), ..., in row-major order.

    This is similar to
    https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498
    with number nodes decremented and col index incremented by one.
    """
    # Total number of strictly-upper-triangular entries is nn // 2.
    nn = n * (n - 1)
    # Invert the triangular-number formula via the quadratic root;
    # computed in double precision to keep sqrt exact for large n.
    row_idx = n - 2 - torch.floor(
        torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long()
    # Column offset within the row, shifted past the diagonal (+1).
    col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div(
        (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor')
    return torch.stack((row_idx, col_idx))
def margin_loss(score: Tensor, target: Tensor) -> Tensor:
    r"""Margin loss between true score and highest non-target score:

    .. math::
        m = - s_{y} + max_{y' \ne y} s_{y'}

    where :math:`m` is the margin :math:`s` the score and :math:`y` the
    target.

    Parameters
    ----------
    score : Tensor
        some score (e.g. prediction) of shape :obj:`[n_elem, dim]`.
    target : LongTensor
        the target of shape :obj:`[n_elem]`.

    Returns
    -------
    Tensor
        the calculated margins
    """
    rows = torch.arange(score.size(0), device=score.device)
    true_score = score[rows, target]

    # Mask the true class out before taking the row-wise maximum so the
    # runner-up is always a *non-target* class.
    masked = score.clone()
    masked[rows, target] = float('-Inf')
    runner_up = masked.amax(dim=-1)

    return runner_up - true_score
def masked_cross_entropy(prediction: Tensor, target: Tensor) -> Tensor:
    r"""Calculate masked cross entropy loss, a node-classification loss
    that restricts the cross entropy to the currently correctly-classified
    nodes (falling back to all nodes if none are correct).

    Parameters
    ----------
    prediction : Tensor
        prediction of shape :obj:`[n_elem, dim]`.
    target : LongTensor
        the target of shape :obj:`[n_elem]`.

    Returns
    -------
    Tensor
        the calculated loss
    """
    correct = prediction.argmax(-1) == target
    # Only attack nodes the model still gets right; if the model is
    # already wrong everywhere, compute the loss over all nodes.
    if correct.any():
        prediction, target = prediction[correct], target[correct]

    return F.cross_entropy(prediction, target)