
modify ResizeNearestNeighborV2D

Tag: v0.3.0-alpha
Author: chang zherui, 5 years ago
Commit: bd13f9ba33
7 changed files with 35 additions and 47 deletions
  1. +0  -2   mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc
  2. +0  -3   mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h
  3. +28 -31  mindspore/ccsrc/transform/op_declare.cc
  4. +1  -3   mindspore/ops/_grad/grad_nn_ops.py
  5. +2  -2   mindspore/ops/operations/_grad_ops.py
  6. +3  -5   mindspore/ops/operations/nn_ops.py
  7. +1  -1   tests/ut/python/ops/test_ops.py

+0 -2  mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc

@@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       FusedBatchNormGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(BatchNorm,
@@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
                         .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
                       FusedBatchNormGpuKernel, half)
 }  // namespace kernel


+0 -3  mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h

@@ -156,9 +156,6 @@ class FusedBatchNormGpuKernel : public GpuKernel {
     output_size_list_.push_back(para_size);  // running variance
     output_size_list_.push_back(para_size);  // save mean
     output_size_list_.push_back(para_size);  // save variance
-    if (!is_train_) {
-      output_size_list_.push_back(para_size);  // reserve
-    }
     return;
   }




+28 -31  mindspore/ccsrc/transform/op_declare.cc

@@ -154,14 +154,14 @@ ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::str
 OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)},
                          {1, OUTPUT_DESC(batch_mean)},
                          {2, OUTPUT_DESC(batch_variance)},
                          {3, OUTPUT_DESC(reserve_space_1)},
-                         {4, OUTPUT_DESC(reserve_space_2)},
-                         {5, OUTPUT_DESC(reserve_space_3)}};
+                         {4, OUTPUT_DESC(reserve_space_2)}};

 // BatchNormGrad
-INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, {2, INPUT_DESC(x)},
-                            {3, INPUT_DESC(scale)}, {4, INPUT_DESC(reserve_space_1)},
-                            {5, INPUT_DESC(reserve_space_2)}, {6, INPUT_DESC(reserve_space_3)}};
+INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)},
+                            {2, INPUT_DESC(x)},
+                            {3, INPUT_DESC(scale)},
+                            {4, INPUT_DESC(reserve_space_1)},
+                            {5, INPUT_DESC(reserve_space_2)}};
 ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
                            {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())},
                            {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
@@ -266,11 +266,6 @@ INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_D
 ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}};

-// ReduceSum
-INPUT_MAP(ReduceSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
-ATTR_MAP(ReduceSum) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}};
-
 // ReduceSumD
 INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}};
 INPUT_ATTR_MAP(ReduceSumD) = {
@@ -451,17 +446,17 @@ INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}};
 ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
 OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}};

-// ResizeNearestNeighborD
-INPUT_MAP(ResizeNearestNeighborD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ResizeNearestNeighborD) = {
+// ResizeNearestNeighborV2D
+INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ResizeNearestNeighborV2D) = {
   {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeNearestNeighborD) = {{0, OUTPUT_DESC(y)}};
+OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}};

-// ResizeNearestNeighborGrad
-INPUT_MAP(ResizeNearestNeighborGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}};
-ATTR_MAP(ResizeNearestNeighborGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeNearestNeighborGrad) = {{0, OUTPUT_DESC(y)}};
+// ResizeNearestNeighborV2Grad
+INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}};
+ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
+OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}};

 // ApplyAdam
 INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
@@ -486,17 +481,17 @@ INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
 ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}};

-// ResizeBilinearGrad
-INPUT_MAP(ResizeBilinearGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
-ATTR_MAP(ResizeBilinearGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeBilinearGrad) = {{0, OUTPUT_DESC(y)}};
+// ResizeBilinearV2Grad
+INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
+ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
+OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}};

-// ResizeBilinearD
-INPUT_MAP(ResizeBilinearD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ResizeBilinearD) = {
+// ResizeBilinearV2D
+INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ResizeBilinearV2D) = {
   {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeBilinearD) = {{0, OUTPUT_DESC(y)}};
+OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}};

 // ZerosLike
 INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}};
@@ -609,10 +604,12 @@ ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits<int>())},
                              {"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
 OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}};

-// ReduceAll
-INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
-ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}}
+// ReduceAllD
+INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}};
+INPUT_ATTR_MAP(ReduceAllD) = {
+  {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
+ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}};

 // ReduceMeanD
 INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}};
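
Note: the renames above only change the GE operator declarations that MindSpore primitives lower to on Ascend; the Python-facing API is untouched. Below is a minimal front-end sketch that would exercise the remapped nearest-neighbor resize, assuming the user-facing primitive is P.ResizeNearestNeighbor and a supported backend is configured (illustrative only, not part of this commit):

# Illustrative sketch (not from this commit): the Python primitive is assumed
# to be P.ResizeNearestNeighbor; on Ascend it should now lower to the GE op
# ResizeNearestNeighborV2D declared above.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class ResizeNet(nn.Cell):
    def __init__(self):
        super(ResizeNet, self).__init__()
        self.resize = P.ResizeNearestNeighbor((4, 4))  # target (height, width)

    def construct(self, x):
        return self.resize(x)

x = Tensor(np.ones((1, 1, 2, 2)).astype(np.float32))
print(ResizeNet()(x).shape)  # expected: (1, 1, 4, 4)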


+1 -3  mindspore/ops/_grad/grad_nn_ops.py

@@ -356,12 +356,10 @@ def get_bprop_batch_norm(self):
         if is_training:
             saved_reserve_1 = out[3]
             saved_reserve_2 = out[4]
-            saved_reserve_3 = out[5]
         else:
             saved_reserve_1 = mean
             saved_reserve_2 = variance
-            saved_reserve_3 = variance
-        out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3)
+        out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
         dx = out[0]
         dscale = out[1]
         dbias = out[2]


+2 -2  mindspore/ops/operations/_grad_ops.py

@@ -69,11 +69,11 @@ class BatchNormGrad(PrimitiveWithInfer):
         self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT)
         self.add_prim_attr('data_format', "NCHW")

-    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
+    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
         validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
         return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)

-    def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type):
+    def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type):
         return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
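
With reserve_space_3 dropped from both infer methods, BatchNormGrad is now a five-input primitive. A minimal sketch of a direct invocation, assuming an executable backend and using the same shapes as the updated unit test below (illustrative only):

# Illustrative sketch: BatchNormGrad now takes exactly five inputs; shapes
# mirror the updated test case in tests/ut/python/ops/test_ops.py.
import numpy as np
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

bn_grad = G.BatchNormGrad()
y_backprop = Tensor(np.ones((128, 64, 32, 32)).astype(np.float32))
x          = Tensor(np.ones((128, 64, 32, 32)).astype(np.float32))
scale      = Tensor(np.ones((64,)).astype(np.float32))
reserve_1  = Tensor(np.ones((64,)).astype(np.float32))
reserve_2  = Tensor(np.ones((64,)).astype(np.float32))
# Per infer_shape, the result is a five-element tuple:
# (dx, dscale, dbias, reserve_1, reserve_2).
outs = bn_grad(y_backprop, x, scale, reserve_1, reserve_2)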






+3 -5  mindspore/ops/operations/nn_ops.py

@@ -537,7 +537,6 @@ class BatchNorm(PrimitiveWithInfer):
         - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
         - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`.
         - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **reserve_space_3** (Tensor) - Tensor of shape :math:`(C,)`.
     """

     @prim_attr_register
@@ -546,8 +545,7 @@ class BatchNorm(PrimitiveWithInfer):
         validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
         self.add_prim_attr('data_format', "NCHW")
         self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
-                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2',
-                                         'reserve_space_3'])
+                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])

     def infer_shape(self, input_x, scale, bias, mean, variance):
         validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
@@ -557,7 +555,7 @@ class BatchNorm(PrimitiveWithInfer):
         validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
         validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
         validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
-        return (input_x, scale, scale, scale, scale, scale)
+        return (input_x, scale, scale, scale, scale)

     def infer_dtype(self, input_x, scale, bias, mean, variance):
         validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
@@ -570,7 +568,7 @@ class BatchNorm(PrimitiveWithInfer):
         else:
             args_moving = {"mean": mean, "variance": variance}
         validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name)
-        return (input_x, scale, bias, input_x, input_x, input_x)
+        return (input_x, scale, bias, input_x, input_x)


 class Conv2D(PrimitiveWithInfer):
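
With the sixth output gone, the forward primitive now returns a five-element tuple. A minimal usage sketch, assuming a configured backend (illustrative only, not part of this commit):

# Illustrative sketch: P.BatchNorm now yields five outputs rather than six.
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

bn = P.BatchNorm(is_training=True, epsilon=1e-5)
x        = Tensor(np.ones((128, 64, 32, 32)).astype(np.float32))
scale    = Tensor(np.ones((64,)).astype(np.float32))
bias     = Tensor(np.zeros((64,)).astype(np.float32))
mean     = Tensor(np.zeros((64,)).astype(np.float32))
variance = Tensor(np.ones((64,)).astype(np.float32))
# Unpacking a sixth reserve_space_3 would now raise a mismatch error.
y, batch_mean, batch_variance, reserve_1, reserve_2 = bn(x, scale, bias, mean, variance)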


+1 -1  tests/ut/python/ops/test_ops.py

@@ -671,7 +671,7 @@ test_case_nn_ops = [
         'skip': []}),
     ('BatchNormGrad', {
         'block': G.BatchNormGrad(),
-        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
+        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
     ('ApplyMomentum', {

