Browse Source

Fix bug in strided slice MS interface

tags/v0.6.0-beta
peixu_ren 5 years ago
parent
commit
d68e44fa93
3 changed files with 13 additions and 11 deletions
  1. +2
    -0
      mindspore/ops/_grad/grad_inner_ops.py
  2. +5
    -5
      mindspore/ops/_op_impl/aicpu/strided_slice.py
  3. +6
    -6
      mindspore/ops/_op_impl/aicpu/strided_slice_grad.py

+ 2
- 0
mindspore/ops/_grad/grad_inner_ops.py View File

@@ -15,6 +15,7 @@

"""array_ops"""

from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@@ -24,6 +25,7 @@ from .grad_base import bprop_getters
@bprop_getters.register(inner.StridedSliceAICPU)
def get_bprop_strided_slice_aicpu(self):
"""Generate bprop for StridedSlice"""
shape_op = P.Shape()
input_grad = G.StridedSliceGradAICPU(self.begin_mask,
self.end_mask,
self.ellipsis_mask,


+ 5
- 5
mindspore/ops/_op_impl/aicpu/strided_slice.py View File

@@ -28,11 +28,11 @@ strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \
.attr("ellipsis_mask", "int") \
.attr("new_axis_mask", "int") \
.attr("shrink_axis_mask", "int") \
.dtype_format(DataType.F32_NCHW,
DataType.I32_NCHW,
DataType.I32_NCHW,
DataType.I32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_Default,
DataType.I32_Default,
DataType.I32_Default,
DataType.I32_Default,
DataType.F32_Default) \
.get_op_info()

@op_info_register(strided_slice_op_info)


+ 6
- 6
mindspore/ops/_op_impl/aicpu/strided_slice_grad.py View File

@@ -29,12 +29,12 @@ strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \
.attr("ellipsis_mask", "int") \
.attr("new_axis_mask", "int") \
.attr("shrink_axis_mask", "int") \
.dtype_format(DataType.F32_NCHW,
DataType.I32_NCHW,
DataType.I32_NCHW,
DataType.I32_NCHW,
DataType.I32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_Default,
DataType.I32_Default,
DataType.I32_Default,
DataType.I32_Default,
DataType.I32_Default,
DataType.F32_Default) \
.get_op_info()

@op_info_register(strided_slice_grad_op_info)


Loading…
Cancel
Save