Browse Source

add dynamic shape support to Zeroslike

tags/v1.1.0
TFbunny 5 years ago
parent
commit
7eced5d59e
3 changed files with 10 additions and 13 deletions
  1. +8
    -0
      mindspore/core/abstract/prim_arrays.cc
  2. +0
    -7
      mindspore/core/abstract/prim_nn.cc
  3. +2
    -6
      mindspore/ops/operations/array_ops.py

+ 8
- 0
mindspore/core/abstract/prim_arrays.cc View File

@@ -377,5 +377,13 @@ AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const Primitive

return tensor->ToAbstract();
}

AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
CheckArgsSize(primitive->name(), args_spec_list, 1);
return args_spec_list[0]->Broaden();
}

} // namespace abstract
} // namespace mindspore

+ 0
- 7
mindspore/core/abstract/prim_nn.cc View File

@@ -296,13 +296,6 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
CheckArgsSize(primitive->name(), args_spec_list, 1);
return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.


+ 2
- 6
mindspore/ops/operations/array_ops.py View File

@@ -1034,7 +1034,7 @@ class OnesLike(PrimitiveWithInfer):
return x_dtype


class ZerosLike(PrimitiveWithInfer):
class ZerosLike(PrimitiveWithCheck):
"""
Creates a new tensor. All elements value are 0.

@@ -1059,12 +1059,8 @@ class ZerosLike(PrimitiveWithInfer):
"""Initialize ZerosLike"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])

def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_dtype):
def check_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype


class TupleToArray(PrimitiveWithInfer):


Loading…
Cancel
Save