Browse Source

!9003 Fix mistakes in docstring.

From: @yuhanshi
Reviewed-by: @wuxuejian,@ouwenchang
Signed-off-by: @wuxuejian
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
829cc849a0
4 changed files with 26 additions and 43 deletions
  1. +4
    -5
      mindspore/explainer/benchmark/_attribution/class_sensitivity.py
  2. +5
    -17
      mindspore/explainer/benchmark/_attribution/robustness.py
  3. +4
    -4
      mindspore/explainer/explanation/_attribution/_perturbation/ablation.py
  4. +13
    -17
      mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py

+ 4
- 5
mindspore/explainer/benchmark/_attribution/class_sensitivity.py View File

@@ -16,10 +16,8 @@


import numpy as np import numpy as np


from mindspore import Tensor
from .metric import LabelAgnosticMetric from .metric import LabelAgnosticMetric
from ... import _operators as ops from ... import _operators as ops
from ...explanation._attribution.attribution import Attribution
from ..._utils import calc_correlation from ..._utils import calc_correlation




@@ -35,12 +33,12 @@ class ClassSensitivity(LabelAgnosticMetric):


""" """


def evaluate(self, explainer: Attribution, inputs: Tensor) -> np.ndarray:
def evaluate(self, explainer, inputs):
""" """
Evaluate class sensitivity on a single data sample. Evaluate class sensitivity on a single data sample.


Args: Args:
explainer (Attribution): The explainer to be evaluated, see `mindspore.explainer.explanation`.
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`. inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.


Returns: Returns:
@@ -49,7 +47,8 @@ class ClassSensitivity(LabelAgnosticMetric):
Examples: Examples:
>>> import mindspore as ms >>> import mindspore as ms
>>> from mindspore.explainer.explanation import Gradient >>> from mindspore.explainer.explanation import Gradient
>>> gradient = Gradient()
>>> model = resnet(10)
>>> gradient = Gradient(model)
>>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) >>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
>>> class_sensitivity = ClassSensitivity() >>> class_sensitivity = ClassSensitivity()
>>> res = class_sensitivity.evaluate(gradient, x) >>> res = class_sensitivity.evaluate(gradient, x)


+ 5
- 17
mindspore/explainer/benchmark/_attribution/robustness.py View File

@@ -14,21 +14,14 @@
# ============================================================================ # ============================================================================
"""Robustness.""" """Robustness."""


from typing import Optional, Union

import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor
from mindspore import log from mindspore import log
from .metric import LabelSensitiveMetric from .metric import LabelSensitiveMetric
from ...explanation._attribution import Attribution
from ...explanation._attribution._perturbation.replacement import RandomPerturb from ...explanation._attribution._perturbation.replacement import RandomPerturb


_Array = np.ndarray
_Label = Union[ms.Tensor, int]



class Robustness(LabelSensitiveMetric): class Robustness(LabelSensitiveMetric):
""" """
@@ -39,12 +32,12 @@ class Robustness(LabelSensitiveMetric):
num_labels (int): Number of classes in the dataset. num_labels (int): Number of classes in the dataset.


Examples: Examples:
>>> from mindspore.explainer.benchmark import Robustness
>>> num_labels = 100
>>> robustness = Robustness(num_labels)
>>> from mindspore.explainer.benchmark import Robustness
>>> num_labels = 100
>>> robustness = Robustness(num_labels)
""" """


def __init__(self, num_labels: int, activation_fn=nn.Softmax()):
def __init__(self, num_labels, activation_fn=nn.Softmax()):
super().__init__(num_labels) super().__init__(num_labels)


self._perturb = RandomPerturb() self._perturb = RandomPerturb()
@@ -52,12 +45,7 @@ class Robustness(LabelSensitiveMetric):
self._threshold = 0.1 # threshold to generate perturbation self._threshold = 0.1 # threshold to generate perturbation
self._activation_fn = activation_fn self._activation_fn = activation_fn


def evaluate(self,
explainer: Attribution,
inputs: Tensor,
targets: _Label,
saliency: Optional[Tensor] = None
) -> _Array:
def evaluate(self, explainer, inputs, targets, saliency=None):
""" """
Evaluate robustness on single sample. Evaluate robustness on single sample.




+ 4
- 4
mindspore/explainer/explanation/_attribution/_perturbation/ablation.py View File

@@ -48,7 +48,7 @@ class Ablation:
inputs (np.ndarray): Input array to perturb. The first dim of inputs is assumed to be the batch size, i.e., inputs (np.ndarray): Input array to perturb. The first dim of inputs is assumed to be the batch size, i.e.,
number of samples. number of samples.
reference (np.ndarray or float): Array of values to replace the elements in the original inputs. The shape reference (np.ndarray or float): Array of values to replace the elements in the original inputs. The shape
of reference must math the inputs. If scalar is provided, the perturbed elements will be assigned the
of reference must match the inputs. If scalar is provided, the perturbed elements will be assigned the
given value.. given value..
masks (np.ndarray): Several boolean array to mark the perturbed positions. True marks the pixels to be masks (np.ndarray): Several boolean array to mark the perturbed positions. True marks the pixels to be
perturbed, otherwise the pixels will be kept. The shape of masks is assumed to be perturbed, otherwise the pixels will be kept. The shape of masks is assumed to be
@@ -134,9 +134,9 @@ class AblationWithSaliency(Ablation):
saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel
saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension. saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension.
num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs, num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs,
num_channels should be provided when input data have channels dimension, even if num_channel. If None is
provided, the inputs is assumed to be no-channel data, and the generated mask will have no channel
dimension. Default: None.
num_channels should be provided when input data have channels dimension, even if num_channel is 1.
If None is provided, the inputs is assumed to be no-channel data, and the generated mask will have
no channel dimension. Default: None.


Return: Return:
mask (np.ndarray): boolen mask for generate perturbations. mask (np.ndarray): boolen mask for generate perturbations.


+ 13
- 17
mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py View File

@@ -15,7 +15,6 @@
"""Occlusion explainer.""" """Occlusion explainer."""


import math import math
from typing import Tuple, Union


import numpy as np import numpy as np
from numpy.lib.stride_tricks import as_strided from numpy.lib.stride_tricks import as_strided
@@ -23,15 +22,11 @@ from numpy.lib.stride_tricks import as_strided
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn import Cell
from .ablation import Ablation from .ablation import Ablation
from .perturbation import PerturbationAttribution from .perturbation import PerturbationAttribution
from .replacement import Constant from .replacement import Constant
from ...._utils import abs_max from ...._utils import abs_max


_Array = np.ndarray
_Label = Union[int, Tensor]



def _generate_patches(array, window_size, stride): def _generate_patches(array, window_size, stride):
"""View as windows.""" """View as windows."""
@@ -67,25 +62,26 @@ class Occlusion(PerturbationAttribution):
network (Cell): Specify the black-box model to be explained. network (Cell): Specify the black-box model to be explained.


Inputs: Inputs:
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
If it is a 1D tensor, its length should be the same as `inputs`.
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
If it is a 1D tensor, its length should be the same as `inputs`.


Outputs: Outputs:
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.


Example: Example:
>>> from mindspore.explainer.explanation import Occlusion >>> from mindspore.explainer.explanation import Occlusion
>>> net = resnet50(10)
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> network = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt") >>> param_dict = load_checkpoint("resnet50.ckpt")
>>> load_param_into_net(net, param_dict)
>>> occlusion = Occlusion(net)
>>> x = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
>>> load_param_into_net(network, param_dict)
>>> occlusion = Occlusion(network)
>>> x = Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
>>> label = 1 >>> label = 1
>>> saliency = occlusion(x, label) >>> saliency = occlusion(x, label)
""" """


def __init__(self, network: Cell, activation_fn: Cell = nn.Softmax()):
def __init__(self, network, activation_fn=nn.Softmax()):
super().__init__(network, activation_fn) super().__init__(network, activation_fn)


self._ablation = Ablation(perturb_mode='Deletion') self._ablation = Ablation(perturb_mode='Deletion')
@@ -94,7 +90,7 @@ class Occlusion(PerturbationAttribution):
self._num_sample_per_dim = 32 # specify the number of perturbations each dimension. self._num_sample_per_dim = 32 # specify the number of perturbations each dimension.
self._num_per_eval = 32 # number of perturbations each evaluation step. self._num_per_eval = 32 # number of perturbations each evaluation step.


def __call__(self, inputs: Tensor, targets: _Label) -> Tensor:
def __call__(self, inputs, targets):
"""Call function for 'Occlusion'.""" """Call function for 'Occlusion'."""
self._verify_data(inputs, targets) self._verify_data(inputs, targets)


@@ -145,11 +141,11 @@ class Occlusion(PerturbationAttribution):
outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6) outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6)
weights += masks.sum(axis=1) weights += masks.sum(axis=1)


attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights))
attribution = self._aggregation_fn(Tensor(total_attribution / weights))
return attribution return attribution


@staticmethod @staticmethod
def _generate_masks(inputs: Tensor, window_size: Tuple[int, ...], strides: Tuple[int, ...]) -> _Array:
def _generate_masks(inputs, window_size, strides):
"""Generate masks to perturb contiguous regions.""" """Generate masks to perturb contiguous regions."""
total_dim = np.prod(inputs.shape[1:]).item() total_dim = np.prod(inputs.shape[1:]).item()
template = np.arange(total_dim).reshape(inputs.shape[1:]) template = np.arange(total_dim).reshape(inputs.shape[1:])


Loading…
Cancel
Save