You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Unit test on mindspore.explainer._utils."""
  16. import numpy as np
  17. import pytest
  18. import mindspore as ms
  19. import mindspore.nn as nn
  20. from mindspore.explainer._utils import (
  21. ForwardProbe,
  22. rank_pixels,
  23. retrieve_layer,
  24. retrieve_layer_by_name)
  25. from mindspore.explainer.explanation._attribution._backprop.backprop_utils import GradNet, get_bp_weights
  26. class CustomNet(nn.Cell):
  27. """Simple net for test."""
  28. def __init__(self):
  29. super(CustomNet, self).__init__()
  30. self.fc1 = nn.Dense(10, 10)
  31. self.fc2 = nn.Dense(10, 10)
  32. self.fc3 = nn.Dense(10, 10)
  33. self.fc4 = nn.Dense(10, 10)
  34. def construct(self, inputs):
  35. out = self.fc1(inputs)
  36. out = self.fc2(out)
  37. out = self.fc3(out)
  38. out = self.fc4(out)
  39. return out
  40. @pytest.mark.level0
  41. @pytest.mark.platform_arm_ascend_training
  42. @pytest.mark.platform_x86_ascend_training
  43. @pytest.mark.env_onecard
  44. def test_rank_pixels():
  45. """Test on rank_pixels."""
  46. saliency = np.array([[4., 3., 1.], [5., 9., 1.]])
  47. descending_target = np.array([[0, 1, 2], [1, 0, 2]])
  48. ascending_target = np.array([[2, 1, 0], [1, 2, 0]])
  49. descending_rank = rank_pixels(saliency)
  50. ascending_rank = rank_pixels(saliency, descending=False)
  51. assert (descending_rank - descending_target).any() == 0
  52. assert (ascending_rank - ascending_target).any() == 0
  53. @pytest.mark.level0
  54. @pytest.mark.platform_arm_ascend_training
  55. @pytest.mark.platform_x86_ascend_training
  56. @pytest.mark.env_onecard
  57. def test_retrieve_layer_by_name():
  58. """Test on rank_pixels."""
  59. model = CustomNet()
  60. target_layer_name = 'fc3'
  61. target_layer = retrieve_layer_by_name(model, target_layer_name)
  62. assert target_layer is model.fc3
  63. @pytest.mark.level0
  64. @pytest.mark.platform_arm_ascend_training
  65. @pytest.mark.platform_x86_ascend_training
  66. @pytest.mark.env_onecard
  67. def test_retrieve_layer_by_name_no_name():
  68. """Test on retrieve layer."""
  69. model = CustomNet()
  70. target_layer = retrieve_layer_by_name(model, '')
  71. assert target_layer is model
  72. @pytest.mark.level0
  73. @pytest.mark.platform_arm_ascend_training
  74. @pytest.mark.platform_x86_ascend_training
  75. @pytest.mark.env_onecard
  76. def test_forward_probe():
  77. """Test case for ForwardProbe."""
  78. model = CustomNet()
  79. model.set_grad()
  80. inputs = np.random.random((1, 10))
  81. inputs = ms.Tensor(inputs, ms.float32)
  82. gt_activation = model.fc3(model.fc2(model.fc1(inputs))).asnumpy()
  83. targets = 1
  84. weights = get_bp_weights(model, inputs, targets=targets)
  85. gradnet = GradNet(model)
  86. grad_before_probe = gradnet(inputs, weights).asnumpy()
  87. # Probe forward tensor
  88. saliency_layer = retrieve_layer(model, 'fc3')
  89. with ForwardProbe(saliency_layer) as probe:
  90. grad_after_probe = gradnet(inputs, weights).asnumpy()
  91. activation = probe.value.asnumpy()
  92. grad_after_unprobe = gradnet(inputs, weights).asnumpy()
  93. assert np.array_equal(gt_activation, activation)
  94. assert np.array_equal(grad_before_probe, grad_after_probe)
  95. assert np.array_equal(grad_before_probe, grad_after_unprobe)
  96. assert probe.value is None