Browse Source

!12228 fix bugs of infer in lenet

From: @lianliguang
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
6b378aeb34
23 changed files with 322 additions and 94 deletions
  1. +4
    -0
      mindspore/core/abstract/infer_functions.h
  2. +17
    -0
      mindspore/core/abstract/prim_nn.cc
  3. +16
    -0
      mindspore/core/abstract/prim_others.cc
  4. +2
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  5. +1
    -0
      mindspore/core/base/core_ops.h
  6. +12
    -15
      mindspore/core/ops/apply_momentum.cc
  7. +1
    -1
      mindspore/core/ops/assert.cc
  8. +42
    -0
      mindspore/core/ops/bias_add.cc
  9. +22
    -20
      mindspore/core/ops/conv2d.cc
  10. +19
    -1
      mindspore/core/ops/fill.cc
  11. +1
    -5
      mindspore/core/ops/gather.cc
  12. +0
    -4
      mindspore/core/ops/grad/conv2d_backprop_filter.cc
  13. +57
    -0
      mindspore/core/ops/grad/conv2d_backprop_input.cc
  14. +1
    -3
      mindspore/core/ops/grad/max_pool_grad.cc
  15. +75
    -0
      mindspore/core/ops/mat_mul.cc
  16. +18
    -12
      mindspore/core/ops/max_pool.cc
  17. +2
    -2
      mindspore/core/ops/merge.cc
  18. +9
    -5
      mindspore/core/ops/shape.cc
  19. +0
    -1
      mindspore/core/ops/zeros_like.cc
  20. +8
    -7
      mindspore/core/utils/check_convert_utils.cc
  21. +2
    -2
      mindspore/core/utils/check_convert_utils.h
  22. +3
    -16
      mindspore/core/utils/tensor_construct_utils.cc
  23. +10
    -0
      mindspore/core/utils/tensor_construct_utils.h

+ 4
- 0
mindspore/core/abstract/infer_functions.h View File

@@ -308,6 +308,10 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T> template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict. // Inputs: a tuple or list or dict.


+ 17
- 0
mindspore/core/abstract/prim_nn.cc View File

@@ -202,6 +202,23 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti
return std::make_shared<AbstractTuple>(elements); return std::make_shared<AbstractTuple>(elements);
} }


AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
auto is_grad = GetValue<bool>(primitive->GetAttr("is_grad"));
CheckArgsSize(primitive->name(), args_spec_list, 2);
std::shared_ptr<BaseShape> shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{});
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
if (is_grad) {
shape = args_spec_list[0]->BuildShape();
}
auto type = args_spec_list[0]->BuildType();
MS_EXCEPTION_IF_NULL(type);
auto type_tensor = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(type_tensor);
return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape);
}

AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).


+ 16
- 0
mindspore/core/abstract/prim_others.cc View File

@@ -559,5 +559,21 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape); ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
return std::make_shared<AbstractTensor>(input->element(), shape); return std::make_shared<AbstractTensor>(input->element(), shape);
} }

AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
auto type = args_spec_list[0]->BuildType();
MS_EXCEPTION_IF_NULL(type);
auto tensor_type = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto value = tensor_type->element();
auto abstract = std::make_shared<abstract::AbstractType>(value);
abstract->set_value(value);
return abstract;
}
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

+ 2
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -193,6 +193,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
{prim::kPrimCast, {InferImplCast, true}}, {prim::kPrimCast, {InferImplCast, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, true}}, {prim::kPrimExpandDims, {InferImplExpandDims, true}},
{prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}},
{prim::kPrimDType, {InferImplDType, true}},
}; };
return prim_eval_implement_map; return prim_eval_implement_map;
} }


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -521,6 +521,7 @@ inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFus
inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion"); inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion");
inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion");
inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion"); inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");


class DoSignaturePrimitive : public Primitive { class DoSignaturePrimitive : public Primitive {
public: public:


+ 12
- 15
mindspore/core/ops/apply_momentum.cc View File

@@ -56,32 +56,29 @@ float ApplyMomentum::get_gradient_scale() const {
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>();
MS_EXCEPTION_IF_NULL(momentum_prim);
auto prim_name = momentum_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);


// Infer shape // Infer shape
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);


// Infer type // Infer type
auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element();
auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element();
auto v_tensor_type = input_args[0]->BuildType();
auto a_tensor_type = input_args[1]->BuildType();
auto l_type = input_args[2]->BuildType();
auto g_type = input_args[3]->BuildType();
auto m_type = input_args[4]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name);
const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32),
TypeIdToType(kNumberTypeFloat64)};
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
std::map<std::string, TypePtr> args; std::map<std::string, TypePtr> args;
args.insert({"l_type", l_type}); args.insert({"l_type", l_type});
args.insert({"g_type", g_type}); args.insert({"g_type", g_type});
args.insert({"m_type", m_type}); args.insert({"m_type", m_type});
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name);

return std::make_shared<abstract::AbstractTensor>(g_type, v_shape);
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name);
auto g_type_tensor = g_type->cast<TensorTypePtr>();
auto element = g_type_tensor->element();
return std::make_shared<abstract::AbstractTensor>(element, v_shape);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer); REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer);
REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);


+ 1
- 1
mindspore/core/ops/assert.cc View File

@@ -61,7 +61,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
condition = input_args[0]->BuildType(); condition = input_args[0]->BuildType();
} }
std::vector<int64_t> output_shape = {1}; std::vector<int64_t> output_shape = {1};
std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)};
std::set<TypeId> local_bool = {kNumberTypeBool};
std::map<std::string, TypePtr> args = {{"condition", condition}}; std::map<std::string, TypePtr> args = {{"condition", condition}};
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name); CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements(); auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();


+ 42
- 0
mindspore/core/ops/bias_add.cc View File

@@ -23,6 +23,42 @@


namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
// Add
namespace {
abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
auto x_channel = x_shape[1];
if (format != NCHW) {
x_channel = x_shape[x_shape.size() - 1];
}
CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, prim_name);

std::vector<int64_t> out_shape = x_shape;
return std::make_shared<abstract::Shape>(out_shape);
}

TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
types.emplace("bias", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
}
} // namespace
void BiasAdd::set_format(const Format &format) { void BiasAdd::set_format(const Format &format) {
int64_t f = format; int64_t f = format;
this->AddAttr(kFormat, MakeValue(f)); this->AddAttr(kFormat, MakeValue(f));
@@ -32,7 +68,13 @@ Format BiasAdd::get_format() const {
return Format(GetValue<int64_t>(value_ptr)); return Format(GetValue<int64_t>(value_ptr));
} }
void BiasAdd::Init(const Format &format) { this->set_format(format); } void BiasAdd::Init(const Format &format) { this->set_format(format); }
AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(BiasAddInferType(primitive, input_args),
BiasAddInferShape(primitive, input_args));
}
// Add // Add
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer);
REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd); REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

+ 22
- 20
mindspore/core/ops/conv2d.cc View File

@@ -30,32 +30,31 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto conv_prim = primitive->cast<PrimConv2dPtr>();
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
if (conv_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
} }


CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]",
w_shape[1], conv_prim->name());
auto out_channel = conv_prim->get_out_channel();
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
"w_shape[1]", w_shape[1], prim_name);
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
std::vector<int64_t> temp_w; std::vector<int64_t> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w,
conv_prim->name());
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
"w_shape[2:4]", temp_w, prim_name);


auto kernel_size_h = w_shape[2]; auto kernel_size_h = w_shape[2];
auto kernel_size_w = w_shape[3]; auto kernel_size_w = w_shape[3];
auto stride = conv_prim->get_stride();
auto dilation = conv_prim->get_dilation();
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
auto stride_h = stride[2]; auto stride_h = stride[2];
auto stride_w = stride[3]; auto stride_w = stride[3];
auto dilation_h = dilation[2]; auto dilation_h = dilation[2];
@@ -63,7 +62,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
int64_t h_out = -1; int64_t h_out = -1;
int64_t w_out = -1; int64_t w_out = -1;
std::vector<int64_t> pad_list(4, 0); std::vector<int64_t> pad_list(4, 0);
auto pad_mode = conv_prim->get_pad_mode();
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
if (pad_mode == VALID) { if (pad_mode == VALID) {
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
@@ -81,20 +80,23 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
pad_list.emplace_back(pad_left); pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left); pad_list.emplace_back(pad_needed_h - pad_left);
} else if (pad_mode == PAD) { } else if (pad_mode == PAD) {
std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
auto pad_top = conv_prim->get_pad()[0];
auto pad_bottom = conv_prim->get_pad()[1];
auto pad_right = conv_prim->get_pad()[2];
auto pad_left = conv_prim->get_pad()[3];
auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));
auto pad_top = pad[0];
auto pad_bottom = pad[1];
auto pad_right = pad[2];
auto pad_left = pad[3];


h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
h_out = floor(h_out); h_out = floor(h_out);
w_out = floor(w_out); w_out = floor(w_out);
} }
conv_prim->set_pad(pad_list);
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
primitive->AddAttr(kPadList,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true)));
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
if (conv_prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {x_shape[0], h_out, w_out, out_channel}; out_shape = {x_shape[0], h_out, w_out, out_channel};
} }




+ 19
- 1
mindspore/core/ops/fill.cc View File

@@ -15,8 +15,10 @@
*/ */


#include "ops/fill.h" #include "ops/fill.h"
#include <memory>
#include "ops/op_utils.h" #include "ops/op_utils.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"


namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
@@ -38,7 +40,23 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
valid_types.insert(kNumberTypeBool); valid_types.insert(kNumberTypeBool);
CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name); CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name);
auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue()); auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
return std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
auto x_type = input_args[2]->BuildType();
auto x_type_id = x_type->type_id();
auto x_value = input_args[2]->BuildValue();
auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
auto mem_size = IntToSize(tensor->ElementsNum());
if (x_type_id == kNumberTypeInt) {
auto num = GetValue<int>(x_value);
SetTensorData(tensor->data_c(), num, mem_size);
} else if (x_type_id == kNumberTypeFloat || x_type_id == kNumberTypeFloat32) {
auto num = GetValue<float>(x_value);
SetTensorData(tensor->data_c(), num, mem_size);
} else {
MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString();
}
abs->set_value(tensor);
return abs;
} }
REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer); REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer);
REGISTER_PRIMITIVE_C(kNameFill, Fill); REGISTER_PRIMITIVE_C(kNameFill, Fill);


+ 1
- 5
mindspore/core/ops/gather.cc View File

@@ -23,15 +23,11 @@ namespace ops {
AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto gather_prim = primitive->cast<PrimGatherPtr>();
MS_EXCEPTION_IF_NULL(gather_prim);
auto prim_name = gather_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name);


// Infer type // Infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
// auto dim_type = input_args[1]->BuildType();
// auto index_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)}; std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64}; const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64};


+ 0
- 4
mindspore/core/ops/grad/conv2d_backprop_filter.cc View File

@@ -27,10 +27,6 @@ namespace {
abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive, abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto conv2d_backprop_filter_prim = primitive->cast<PrimConv2DBackpropFilterPtr>();
MS_EXCEPTION_IF_NULL(conv2d_backprop_filter_prim);
// auto prim_name = conv2d_backprop_filter_prim->name();

auto out_put = input_args[2]->BuildValue(); auto out_put = input_args[2]->BuildValue();
auto infer_shape = GetValue<std::vector<int64_t>>(out_put); auto infer_shape = GetValue<std::vector<int64_t>>(out_put);
return std::make_shared<abstract::Shape>(infer_shape); return std::make_shared<abstract::Shape>(infer_shape);


+ 57
- 0
mindspore/core/ops/grad/conv2d_backprop_input.cc View File

@@ -23,6 +23,62 @@


namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto doutput = input_args[0];
auto x_size = input_args[2];
auto x_size_value = x_size->GetValueTrack();
MS_EXCEPTION_IF_NULL(x_size);
auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
// infer dtype
auto dtype = doutput->BuildType();
if (!dtype->isa<TensorType>()) {
MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString();
}
auto input_tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_tensor_type);
auto element = input_tensor_type->element();
// infer shape
auto dout_shape = doutput->BuildShape();
MS_EXCEPTION_IF_NULL(doutput);
auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
auto dout_shape_norm = dout_shapeptr->shape();
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) {
primitive->AddAttr(kPadList, MakeValue(pad_list));
} else if (pad_mode == SAME) {
auto stride_h = stride[0];
auto stride_w = stride[1];
auto kernel_h = kernel_size[0];
auto kernel_w = kernel_size[1];
auto dilation_h = dilation[2];
auto dilation_w = dilation[3];
auto pad_needed_h = (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2];
pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
auto pad_top = pad_needed_h / 2;
auto pad_bottom = pad_needed_h - pad_top;
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2];
pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
auto pad_left = pad_needed_w / 2;
auto pad_right = pad_needed_w - pad_left;
pad_list = {pad_top, pad_bottom, pad_left, pad_right};
} else if (pad_mode == PAD) {
pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
}
primitive->AddAttr(kPadList, MakeValue(pad_list));
return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
}

void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
const PadMode &pad_mode, const std::vector<int64_t> &pad, const PadMode &pad_mode, const std::vector<int64_t> &pad,
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group, const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
@@ -140,6 +196,7 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList); auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer);
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput); REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

+ 1
- 3
mindspore/core/ops/grad/max_pool_grad.cc View File

@@ -52,9 +52,7 @@ void MaxPoolGrad::set_strides(const std::vector<int64_t> &strides) {


AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>();
MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim);
auto op_name = MaxPoolGrad_prim->name();
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();


+ 75
- 0
mindspore/core/ops/mat_mul.cc View File

@@ -21,6 +21,81 @@


namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
namespace {
abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB));

auto out_n = x_shape[0];
auto out_m = w_shape[1];
auto x_C = x_shape[1];
auto w_C = w_shape[0];

if (trans_a) {
out_n = x_shape[1];
x_C = x_shape[0];
}
if (trans_b) {
out_m = w_shape[0];
w_C = w_shape[1];
}
CheckAndConvertUtils::CheckInteger("dim C is not equal", x_C, kEqual, w_C, prim_name);
primitive->AddAttr("transpose_x1", MakeValue(trans_a));
primitive->AddAttr("transpose_x2", MakeValue(trans_b));
std::vector<int64_t> out_shape = {out_n, out_m};
return std::make_shared<abstract::Shape>(out_shape);
}

TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
}
return TypeIdToType(infer_type);
}
} // namespace

void MatMul::Init(bool transpose_a, bool transpose_b) {
set_transpose_a(transpose_a);
set_transpose_b(transpose_b);
}

void MatMul::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); }

void MatMul::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); }

bool MatMul::get_transpose_a() const {
auto value_ptr = GetAttr(kTransposeA);
return GetValue<bool>(value_ptr);
}

bool MatMul::get_transpose_b() const {
auto value_ptr = GetAttr(kTransposeB);
return GetValue<bool>(value_ptr);
}

// Add
AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(MatMulInferType(primitive, input_args),
MatMulInferShape(primitive, input_args)->shape());
}

// Add
REGISTER_PRIMITIVE_EVAL_IMPL(MatMul, prim::kPrimMatMul, MatMulInfer);
REGISTER_PRIMITIVE_C(kNameMatMul, MatMul); REGISTER_PRIMITIVE_C(kNameMatMul, MatMul);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

+ 18
- 12
mindspore/core/ops/max_pool.cc View File

@@ -94,22 +94,22 @@ void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto pool_prim = primitive->cast<PrimMaxPoolPtr>();
MS_EXCEPTION_IF_NULL(pool_prim);
auto op_name = pool_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
if (pool_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
} }
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
auto kernel_size = pool_prim->get_kernel_size();
auto pad_mode = pool_prim->get_pad_mode();
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode_value = (primitive->GetAttr(kPadMode));
PadMode pad_mode = PAD;
pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
auto batch = in_shape[0]; auto batch = in_shape[0];
auto channel = in_shape[1]; auto channel = in_shape[1];
auto in_h = in_shape[2]; auto in_h = in_shape[2];
auto in_w = in_shape[3]; auto in_w = in_shape[3];

auto strides = pool_prim->get_strides();
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
auto kernel_h = kernel_size[2]; auto kernel_h = kernel_size[2];
auto kernel_w = kernel_size[3]; auto kernel_w = kernel_size[3];
auto stride_h = strides[2]; auto stride_h = strides[2];
@@ -117,14 +117,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
int64_t out_h = -1; int64_t out_h = -1;
int64_t out_w = -1; int64_t out_w = -1;
if (pad_mode == VALID) { if (pad_mode == VALID) {
out_h = ceil((in_h - (kernel_h - 1)) / stride_h);
out_w = ceil((in_w - (kernel_w - 1)) / stride_w);
out_h = ceil((in_h - (kernel_h - 1) + stride_h - 1) / stride_h);
out_w = ceil((in_w - (kernel_w - 1) + stride_w - 1) / stride_w);
} else if (pad_mode == SAME) { } else if (pad_mode == SAME) {
out_h = ceil(in_h / stride_h); out_h = ceil(in_h / stride_h);
out_w = ceil(in_w / stride_w); out_w = ceil(in_w / stride_w);
} }
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
if (pool_prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {batch, out_h, out_w, channel}; out_shape = {batch, out_h, out_w, channel};
} }
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
@@ -137,7 +137,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
return input_args[0]->BuildType();
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
auto input_tensor_type = input_type->cast<TensorTypePtr>();
if (input_tensor_type == nullptr) {
MS_LOG_EXCEPTION << "The maxpool's input must be a tensor but got " << input_type->ToString();
}
return input_tensor_type->element();
} }
} // namespace } // namespace




+ 2
- 2
mindspore/core/ops/merge.cc View File

@@ -38,9 +38,9 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) { for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) {
args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]}); args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]});
} }
std::set<TypePtr> template_type = {TypeIdToType(kNumberTypeBool)};
std::set<TypeId> template_type = {kNumberTypeBool};
for (auto item : common_valid_types) { for (auto item : common_valid_types) {
template_type.insert(TypeIdToType(item));
template_type.insert(item);
} }
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name); CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name);
std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape(); std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape();


+ 9
- 5
mindspore/core/ops/shape.cc View File

@@ -30,13 +30,17 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
// infer shape // infer shape
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto shape_prim = primitive->cast<PrimShapePtr>();
MS_EXCEPTION_IF_NULL(shape_prim);
auto op_name = shape_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
// infer type // infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
AbstractBasePtrList abs_list;
std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
[](int64_t item) -> std::shared_ptr<abstract::AbstractScalar> {
return std::make_shared<abstract::AbstractScalar>(item);
});
auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
abs->set_value(MakeValue(in_shape));
return abs;
} }
REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer); REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer);
REGISTER_PRIMITIVE_C(kNameShape, Shape); REGISTER_PRIMITIVE_C(kNameShape, Shape);


+ 0
- 1
mindspore/core/ops/zeros_like.cc View File

@@ -60,7 +60,6 @@ AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const Primit
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape()); InferShape(primitive, input_args)->shape());
} }
REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer);
REGISTER_PRIMITIVE_C(kNameZerosLike, ZerosLike); REGISTER_PRIMITIVE_C(kNameZerosLike, ZerosLike);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

+ 8
- 7
mindspore/core/utils/check_convert_utils.cc View File

@@ -486,7 +486,7 @@ void CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const Typ
} }


void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
const std::set<TypePtr> &valid_values,
const std::set<TypeId> &valid_values,
const std::string &prim_name, const bool allow_mix) { const std::string &prim_name, const bool allow_mix) {
std::vector<std::map<std::string, TypePtr>> check_results; std::vector<std::map<std::string, TypePtr>> check_results;
for (auto &iter : args) { for (auto &iter : args) {
@@ -502,7 +502,7 @@ void CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::stri
} }


std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const std::map<std::string, TypePtr> &arg, std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const std::map<std::string, TypePtr> &arg,
const std::set<TypePtr> &valid_values,
const std::set<TypeId> &valid_values,
const std::string &prim_name) { const std::string &prim_name) {
std::string arg_key = arg.begin()->first; std::string arg_key = arg.begin()->first;
TypePtr arg_val = arg.begin()->second; TypePtr arg_val = arg.begin()->second;
@@ -512,15 +512,16 @@ std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckArgumentType(const st
arg_val = arg_val_->element(); arg_val = arg_val_->element();
} }


auto it = valid_values.find(arg_val);
auto it = valid_values.find(arg_val->type_id());
if (it == valid_values.end()) { if (it == valid_values.end()) {
std::ostringstream buffer; std::ostringstream buffer;
buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { "; buffer << "For '" << prim_name << "' , the `" << arg_key << "` should be in { ";
for (auto valid_value : valid_values) { for (auto valid_value : valid_values) {
buffer << valid_value->ToString() << " },";
buffer << "but `" << arg_key << "`"
<< "is" << arg_val->ToString() << ".";
buffer << TypeIdToType(valid_value)->ToString() << ",";
} }
buffer << " },";
buffer << "but `" << arg_key << "`"
<< "is" << arg_val->ToString() << ".";
MS_EXCEPTION(TypeError) << buffer.str(); MS_EXCEPTION(TypeError) << buffer.str();
} }
return arg; return arg;
@@ -546,7 +547,7 @@ std::map<std::string, TypePtr> CheckAndConvertUtils::_CheckTypeSame(const std::m
except_flag = true; except_flag = true;
} }


if (except_flag || arg1_type != arg2_type) {
if (except_flag || arg1_type->type_id() != arg2_type->type_id()) {
std::ostringstream buffer; std::ostringstream buffer;
buffer << "For '" << prim_name << "'" buffer << "For '" << prim_name << "'"
<< "type of " << "type of "


+ 2
- 2
mindspore/core/utils/check_convert_utils.h View File

@@ -277,7 +277,7 @@ class CheckAndConvertUtils {
static void CheckSubClass(const std::string &type_name, const TypePtr type, const std::set<TypePtr> &template_types, static void CheckSubClass(const std::string &type_name, const TypePtr type, const std::set<TypePtr> &template_types,
const std::string &prim_name); const std::string &prim_name);
static void CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, static void CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
const std::set<TypePtr> &valid_values, const std::string &prim_name,
const std::set<TypeId> &valid_values, const std::string &prim_name,
bool allow_mix = false); bool allow_mix = false);
static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set<TypeId> &valid_type, static TypeId CheckTypeSame(const std::string &arg_name, const TypePtr arg_type, const std::set<TypeId> &valid_type,
const std::string &prim_name); const std::string &prim_name);
@@ -291,7 +291,7 @@ class CheckAndConvertUtils {
private: private:
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2); static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
static std::map<std::string, TypePtr> _CheckArgumentType(const std::map<std::string, TypePtr> &arg, static std::map<std::string, TypePtr> _CheckArgumentType(const std::map<std::string, TypePtr> &arg,
const std::set<TypePtr> &valid_values,
const std::set<TypeId> &valid_values,
const std::string &prim_name); const std::string &prim_name);
static std::map<std::string, TypePtr> _CheckTypeSame(const std::map<std::string, TypePtr> &arg1, static std::map<std::string, TypePtr> _CheckTypeSame(const std::map<std::string, TypePtr> &arg1,
const std::map<std::string, TypePtr> &arg2, const std::map<std::string, TypePtr> &arg2,


+ 3
- 16
mindspore/core/utils/tensor_construct_utils.cc View File

@@ -17,21 +17,8 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
namespace {
template <typename T>
void SetTensorData(void *data, float num, size_t data_length) {
MS_EXCEPTION_IF_NULL(data);
auto tensor_data = reinterpret_cast<T *>(data);
MS_EXCEPTION_IF_NULL(tensor_data);
for (size_t index = 0; index < data_length; ++index) {
*tensor_data = num;
++tensor_data;
}
}
} // namespace
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);

size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
auto tensor_data = tensor->data_c(); auto tensor_data = tensor->data_c();
char *data = reinterpret_cast<char *>(tensor_data); char *data = reinterpret_cast<char *>(tensor_data);
@@ -43,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std


tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
auto mem_size = IntToSize(tensor->ElementsNum());
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
if (tensor->data_type() == kNumberTypeFloat32) { if (tensor->data_type() == kNumberTypeFloat32) {
SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
SetTensorData(tensor->data_c(), 1.0, mem_size);
} else if (tensor->data_type() == kNumberTypeInt) { } else if (tensor->data_type() == kNumberTypeInt) {
SetTensorData<int>(tensor->data_c(), 1, mem_size);
SetTensorData(tensor->data_c(), 1, mem_size);
} }
return tensor; return tensor;
} }


+ 10
- 0
mindspore/core/utils/tensor_construct_utils.h View File

@@ -18,6 +18,16 @@
#include <vector> #include <vector>
#include "ir/tensor.h" #include "ir/tensor.h"
namespace mindspore { namespace mindspore {
template <typename T>
void SetTensorData(void *data, T num, size_t data_length) {
MS_EXCEPTION_IF_NULL(data);
auto tensor_data = reinterpret_cast<T *>(data);
MS_EXCEPTION_IF_NULL(tensor_data);
for (size_t index = 0; index < data_length; ++index) {
*tensor_data = num;
++tensor_data;
}
}
class TensorConstructUtils { class TensorConstructUtils {
public: public:
static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape); static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape);


Loading…
Cancel
Save