You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Unit test on mindspore.explainer._utils."""
  16. import numpy as np
  17. import pytest
  18. import mindspore as ms
  19. import mindspore.nn as nn
  20. from mindspore import context
  21. from mindspore.explainer._utils import (
  22. ForwardProbe,
  23. rank_pixels,
  24. retrieve_layer,
  25. retrieve_layer_by_name)
  26. from mindspore.explainer.explanation._attribution._backprop.backprop_utils import GradNet, get_bp_weights
  27. context.set_context(mode=context.PYNATIVE_MODE)
  28. class CustomNet(nn.Cell):
  29. """Simple net for test."""
  30. def __init__(self):
  31. super(CustomNet, self).__init__()
  32. self.fc1 = nn.Dense(10, 10)
  33. self.fc2 = nn.Dense(10, 10)
  34. self.fc3 = nn.Dense(10, 10)
  35. self.fc4 = nn.Dense(10, 10)
  36. def construct(self, inputs):
  37. out = self.fc1(inputs)
  38. out = self.fc2(out)
  39. out = self.fc3(out)
  40. out = self.fc4(out)
  41. return out
  42. @pytest.mark.level0
  43. @pytest.mark.platform_arm_ascend_training
  44. @pytest.mark.platform_x86_ascend_training
  45. @pytest.mark.env_onecard
  46. def test_rank_pixels():
  47. """Test on rank_pixels."""
  48. saliency = np.array([[4., 3., 1.], [5., 9., 1.]])
  49. descending_target = np.array([[0, 1, 2], [1, 0, 2]])
  50. ascending_target = np.array([[2, 1, 0], [1, 2, 0]])
  51. descending_rank = rank_pixels(saliency)
  52. ascending_rank = rank_pixels(saliency, descending=False)
  53. assert (descending_rank - descending_target).any() == 0
  54. assert (ascending_rank - ascending_target).any() == 0
  55. @pytest.mark.level0
  56. @pytest.mark.platform_arm_ascend_training
  57. @pytest.mark.platform_x86_ascend_training
  58. @pytest.mark.env_onecard
  59. def test_retrieve_layer_by_name():
  60. """Test on rank_pixels."""
  61. model = CustomNet()
  62. target_layer_name = 'fc3'
  63. target_layer = retrieve_layer_by_name(model, target_layer_name)
  64. assert target_layer is model.fc3
  65. @pytest.mark.level0
  66. @pytest.mark.platform_arm_ascend_training
  67. @pytest.mark.platform_x86_ascend_training
  68. @pytest.mark.env_onecard
  69. def test_retrieve_layer_by_name_no_name():
  70. """Test on retrieve layer."""
  71. model = CustomNet()
  72. target_layer = retrieve_layer_by_name(model, '')
  73. assert target_layer is model
  74. @pytest.mark.level0
  75. @pytest.mark.platform_arm_ascend_training
  76. @pytest.mark.platform_x86_ascend_training
  77. @pytest.mark.env_onecard
  78. def test_forward_probe():
  79. """Test case for ForwardProbe."""
  80. model = CustomNet()
  81. model.set_grad()
  82. inputs = np.random.random((1, 10))
  83. inputs = ms.Tensor(inputs, ms.float32)
  84. gt_activation = model.fc3(model.fc2(model.fc1(inputs))).asnumpy()
  85. targets = 1
  86. weights = get_bp_weights(model, inputs, targets=targets)
  87. gradnet = GradNet(model)
  88. grad_before_probe = gradnet(inputs, weights).asnumpy()
  89. # Probe forward tensor
  90. saliency_layer = retrieve_layer(model, 'fc3')
  91. with ForwardProbe(saliency_layer) as probe:
  92. grad_after_probe = gradnet(inputs, weights).asnumpy()
  93. activation = probe.value.asnumpy()
  94. grad_after_unprobe = gradnet(inputs, weights).asnumpy()
  95. assert np.array_equal(gt_activation, activation)
  96. assert np.array_equal(grad_before_probe, grad_after_probe)
  97. assert np.array_equal(grad_before_probe, grad_after_unprobe)
  98. assert probe.value is None