You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_auto_grad.py 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.nn as nn
  17. from mindspore import context
  18. from mindspore import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import composite as C
  21. grad_all = C.GradOperation(get_all=True)
  22. class CropAndResizeNet(nn.Cell):
  23. def __init__(self, crop_size):
  24. super(CropAndResizeNet, self).__init__()
  25. self.crop_and_resize = P.CropAndResize()
  26. self.crop_size = crop_size
  27. def construct(self, x, boxes, box_indices):
  28. return self.crop_and_resize(x, boxes, box_indices, self.crop_size)
  29. def bprop(self, x, boxes, box_indices, out, dout):
  30. return x, boxes, box_indices
  31. class TestUserDefinedBpropNet(nn.Cell):
  32. def __init__(self, in_channel, out_channel):
  33. super(TestUserDefinedBpropNet, self).__init__()
  34. self.relu = nn.ReLU()
  35. self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False,
  36. weight_init='ones', pad_mode='same')
  37. self.crop = CropAndResizeNet((10, 10))
  38. self.boxes = Tensor(np.ones((128, 4)).astype(np.float32))
  39. self.box_indices = Tensor(np.ones((128,)).astype(np.int32))
  40. def construct(self, x):
  41. x = self.relu(x)
  42. x = self.conv(x)
  43. x = self.crop(x, self.boxes, self.box_indices)
  44. return x
  45. class TestUserDefinedBpropGradNet(nn.Cell):
  46. def __init__(self, net):
  47. super(TestUserDefinedBpropGradNet, self).__init__()
  48. self.net = net
  49. def construct(self, x):
  50. return grad_all(self.net)(x)
  51. def test_user_defined_bprop():
  52. context.set_context(mode=context.GRAPH_MODE)
  53. net = TestUserDefinedBpropNet(3, 10)
  54. grad_net = TestUserDefinedBpropGradNet(net)
  55. x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32))
  56. grad_net(x)