|
|
@@ -26,6 +26,7 @@ import numbers |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from mindspore import log as logger |
|
|
from mindspore import log as logger |
|
|
|
|
|
from mindspore.common.initializer import Zero |
|
|
from .._utils import get_concat_offset |
|
|
from .._utils import get_concat_offset |
|
|
from ..operations.math_ops import _infer_shape_reduce |
|
|
from ..operations.math_ops import _infer_shape_reduce |
|
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op |
|
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op |
|
|
@@ -38,6 +39,7 @@ from ...common.parameter import Parameter |
|
|
from ...common.tensor import Tensor |
|
|
from ...common.tensor import Tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ScatterOp(PrimitiveWithInfer): |
|
|
class _ScatterOp(PrimitiveWithInfer): |
|
|
""" |
|
|
""" |
|
|
Defines Scatter operators |
|
|
Defines Scatter operators |
|
|
@@ -3015,8 +3017,13 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) |
|
|
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) |
|
|
|
|
|
|
|
|
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type(), |
|
|
|
|
|
check_zero_dims=False) |
|
|
|
|
|
|
|
|
if all(ret_shape): |
|
|
|
|
|
value = None |
|
|
|
|
|
else: |
|
|
|
|
|
init_func = Zero() |
|
|
|
|
|
init_func.__enable_zero_dim__ = True |
|
|
|
|
|
value = Tensor(dtype=x['dtype'].element_type(), shape=ret_shape, init=init_func) |
|
|
|
|
|
|
|
|
if "max_value" in x and "min_value" in x: |
|
|
if "max_value" in x and "min_value" in x: |
|
|
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name) |
|
|
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name) |
|
|
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name) |
|
|
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name) |
|
|
|