Merge pull request !3209 from liuxiao93/AddNtags/v0.6.0-beta
@@ -149,6 +149,8 @@ class _BatchNorm(Cell):
     def construct(self, x):
         if self.input_dims == '2d':
             _shape_check(self.shape(x))
+        if self.input_dims == '1d':
+            _shape_check_2d(self.shape(x))
         if self.use_batch_statistics is None:
             flag = self.training
         else:
@@ -200,6 +202,12 @@ def _channel_check(channel, num_channel):
         raise ValueError("the input channel is not equal with num_channel")
+
+
+@constexpr
+def _shape_check_2d(input_shape):
+    if len(input_shape) != 2:
+        raise ValueError("The input must has 2 dims.")
+
 @constexpr
 def _shape_check(in_shape):
     if len(in_shape) != 4:
@@ -980,7 +980,8 @@ def get_bprop_scalar_accumulatenv2(self):
         dx = ()
         for _ in range(len(x)):
             dx = dx + (dout,)
-        return dx
+        return (dx,)
     return bprop
@@ -992,7 +993,7 @@ def get_bprop_scalar_addn(self):
         dx = ()
         for _ in range(len(x)):
             dx = dx + (dout,)
-        return dx
+        return (dx,)
     return bprop
| @@ -1671,13 +1671,11 @@ test_case_array_ops = [ | |||||
| ('AddN', { | ('AddN', { | ||||
| 'block': NetForTupleInput(P.AddN()), | 'block': NetForTupleInput(P.AddN()), | ||||
| 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], | 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], | ||||
| 'desc_bprop': [[2, 3, 3, 5]], | |||||
| 'skip': ['backward']}), | |||||
| 'desc_bprop': [[2, 3, 3, 5]]}), | |||||
| ('AccumulateNV2', { | ('AccumulateNV2', { | ||||
| 'block': NetForTupleInput(P.AccumulateNV2()), | 'block': NetForTupleInput(P.AccumulateNV2()), | ||||
| 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], | 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], | ||||
| 'desc_bprop': [[2, 3, 3, 5]], | |||||
| 'skip': ['backward']}), | |||||
| 'desc_bprop': [[2, 3, 3, 5]]}), | |||||
| ('Shape', { | ('Shape', { | ||||
| 'block': P.Shape(), | 'block': P.Shape(), | ||||
| 'desc_inputs': [[3, 3, 2, 2]], | 'desc_inputs': [[3, 3, 2, 2]], | ||||
@@ -67,10 +67,10 @@ def test_bn2d():
 def test_bn1d():
     """ut of nn.BatchNorm1d"""
     bn = nn.BatchNorm1d(3)
-    input_data = Tensor(np.random.randint(0, 1, [1, 3, 100, 100]).astype(np.float32))
+    input_data = Tensor(np.random.randint(0, 1, [1, 3]).astype(np.float32))
     output = bn(input_data)
     output_np = output.asnumpy()
-    assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
+    assert isinstance(output_np[0][0], (np.float32, np.float64))


 def test_bn2d_train():