
test_debug_location.py
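These unit tests exercise MindSpore's graph-mode error reporting. MockNeg is a primitive whose infer_dtype deliberately raises a TypeError, and MockSub is a subtract-like primitive whose registered backward calls MockNeg. The three tests below trigger, in turn: a type-inference failure during forward compilation, a name-resolution failure on an undefined symbol, and a type-inference failure while compiling gradients under dynamic loss scaling.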

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.nn.optim import Momentum
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._grad.grad_math_ops import binop_grad_common
from mindspore.ops._utils import get_broadcast_shape
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

context.set_context(mode=context.GRAPH_MODE)


class MockNeg(PrimitiveWithInfer):
    """Mock unary op whose dtype inference always fails."""
    @prim_attr_register
    def __init__(self):
        """init MockNeg"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, input_x):
        return input_x

    def infer_dtype(self, input_x):
        # Deliberately broken: raising here makes graph compilation of any
        # network containing this primitive fail with a TypeError.
        raise TypeError("InferError")
        # return input_x


class MockSub(PrimitiveWithInfer):
    """Mock binary subtract; its registered backward (below) uses MockNeg."""
    @prim_attr_register
    def __init__(self):
        """init MockSub"""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def infer_shape(self, x_shape, y_shape):
        return get_broadcast_shape(x_shape, y_shape)

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


@bprop_getters.register(MockSub)
def get_bprop_mock_sub(self):
    """Grad definition for `MockSub` operation."""
    neg_func = MockNeg()

    def bprop(x, y, out, dout):
        # The gradient w.r.t. y passes through MockNeg, so compiling this
        # backward graph hits MockNeg's failing infer_dtype.
        return binop_grad_common(x, y, dout, neg_func(dout))

    return bprop


class Net(nn.Cell):
    """A single FP32 affine layer: output = input @ weight + bias."""
    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.add = P.TensorAdd()

    def construct(self, input_):
        output = self.add(self.matmul(input_, self.weight), self.bias)
        return output


class NetFP16(nn.Cell):
    """The same affine layer computed in float16, with a float32 output."""
    def __init__(self, in_features, out_features):
        super(NetFP16, self).__init__()
        self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.add = P.TensorAdd()
        self.cast = P.Cast()

    def construct(self, input_):
        output = self.cast(
            self.add(self.matmul(self.cast(input_, mstype.float16), self.cast(self.weight, mstype.float16)),
                     self.cast(self.bias, mstype.float16)), mstype.float32)
        return output


def get_axis(x):
    # Build the tuple (0, 1, ..., ndim - 1) so a reduction covers every axis.
    shape = F.shape(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm


class MSELoss(nn.Cell):
    """MSE loss built on MockSub, so its gradient graph contains MockNeg."""
    def __init__(self):
        super(MSELoss, self).__init__()
        self.reduce_sum = P.ReduceSum()
        self.square = P.Square()
        self.reduce_mean = P.ReduceMean()
        self.sub = MockSub()

    def construct(self, data, label):
        diff = self.sub(data, label)
        return self.reduce_mean(self.square(diff), get_axis(diff))


class NegCell(nn.Cell):
    def __init__(self):
        super(NegCell, self).__init__()
        self.neg = MockNeg()

    def construct(self, x):
        return self.neg(x)


class Net3(nn.Cell):
    def __init__(self):
        super().__init__()
        self.tuple = (NegCell(), nn.ReLU())

    def construct(self, x):
        for op in self.tuple:
            x = op(x)
        return x


def test_op_forward_infererror():
    """Forward compilation must fail: Net3 contains MockNeg, whose infer_dtype raises."""
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    net = Net3()
    with pytest.raises(TypeError):
        net(input_me)


class SequenceNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.seq = nn.SequentialCell([nn.AvgPool2d(3, 1), nn.ReLU(), nn.Flatten()])

    def construct(self, x):
        # `bbb` is intentionally undefined to provoke a symbol-resolution error.
        x = self.seq(x) + bbb
        return x


def test_sequential_resolve_error():
    """Graph compilation must fail with a RuntimeError on the undefined name `bbb`."""
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    net = SequenceNet()
    with pytest.raises(RuntimeError):
        net(input_me)


def test_compile_grad_error():
    """Backward compilation must fail: MSELoss uses MockSub, whose bprop calls MockNeg."""
    inputs = Tensor(np.ones([16, 16]).astype(np.float32))
    label = Tensor(np.zeros([16, 16]).astype(np.float32))
    lr = Tensor(np.ones([1], np.float32) * 0.1)
    net = NetFP16(16, 16)
    loss = MSELoss()
    optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)
    net_with_loss = WithLossCell(net, loss)
    scale_manager = DynamicLossScaleManager()
    update_cell = scale_manager.get_update_cell()
    train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=update_cell)
    train_network.set_train()
    with pytest.raises(TypeError) as e:
        train_network(inputs, label)
    print(e)
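
These are ordinary pytest test functions. A minimal sketch for running them, assuming the file is saved as test_debug_location.py in the working directory (pytest.main is pytest's standard programmatic entry point):

    import pytest

    # Collect and run the three tests in this file with verbose output.
    pytest.main(["-v", "test_debug_location.py"])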