| @@ -30,11 +30,10 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -26,11 +26,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve | |||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| // infer shape | // infer shape | ||||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name); | |||||
| auto grad_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name); | |||||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]; | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape]; | |||||
| auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape]; | |||||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); | CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); | ||||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); | CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); | ||||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); | CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); | ||||
| @@ -38,15 +38,13 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||||
| CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | ||||
| auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(element0); | MS_EXCEPTION_IF_NULL(element0); | ||||
| auto element0_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | |||||
| auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape]; | |||||
| std::map<std::string, TypePtr> types; | std::map<std::string, TypePtr> types; | ||||
| types.emplace("element0", element0->BuildType()); | types.emplace("element0", element0->BuildType()); | ||||
| for (size_t i = 1; i < elements.size(); ++i) { | for (size_t i = 1; i < elements.size(); ++i) { | ||||
| std::string elementi = "element" + std::to_string(i); | std::string elementi = "element" + std::to_string(i); | ||||
| auto elementi_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); | |||||
| auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | ||||
| prim_name); | prim_name); | ||||
| for (size_t j = 0; j < element0_shape.size(); ++j) { | for (size_t j = 0; j < element0_shape.size(); ++j) { | ||||
| @@ -60,7 +60,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| // Infer type | // Infer type | ||||
| auto v_tensor_type = input_args[0]->BuildType(); | auto v_tensor_type = input_args[0]->BuildType(); | ||||
| @@ -23,7 +23,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto x_rank = SizeToLong(x_shape.size()); | auto x_rank = SizeToLong(x_shape.size()); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | ||||
| axis = axis < 0 ? axis + x_rank : axis; | axis = axis < 0 ? axis + x_rank : axis; | ||||
| @@ -42,7 +42,7 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| // Infer shape | // Infer shape | ||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto x_rank = SizeToLong(x_shape.size()); | auto x_rank = SizeToLong(x_shape.size()); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | ||||
| if (axis < 0) { | if (axis < 0) { | ||||
| @@ -29,7 +29,7 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||||
| CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer Shape | // Infer Shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | ||||
| // Infer Type | // Infer Type | ||||
| @@ -47,8 +47,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| } | } | ||||
| condition = TypeIdToType(kNumberTypeBool); | condition = TypeIdToType(kNumberTypeBool); | ||||
| } else { | } else { | ||||
| auto condition_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto condition_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name); | ||||
| if (condition_shape[0] == 1) { | if (condition_shape[0] == 1) { | ||||
| auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c()); | auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c()); | ||||
| @@ -25,9 +25,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto value_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(value_shape); | return std::make_shared<abstract::Shape>(value_shape); | ||||
| } | } | ||||
| @@ -27,7 +27,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||||
| CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer Shape | // Infer Shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | ||||
| // Infer Type | // Infer Type | ||||
| @@ -30,9 +30,7 @@ namespace { | |||||
| abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| if (input_shape.size() != 2) { | if (input_shape.size() != 2) { | ||||
| MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | ||||
| } | } | ||||
| @@ -82,7 +82,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| @@ -75,20 +75,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); | CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); | ||||
| auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; | input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; | ||||
| } | } | ||||
| auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); | |||||
| auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name); | |||||
| auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name); | |||||
| auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); | |||||
| auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; | |||||
| auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape]; | |||||
| std::vector<int64_t> input_shape_norm; | std::vector<int64_t> input_shape_norm; | ||||
| if (format == NCHW) { | if (format == NCHW) { | ||||
| input_shape_norm = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| } else { | } else { | ||||
| input_shape_norm.push_back(input_x[0]); | input_shape_norm.push_back(input_x[0]); | ||||
| input_shape_norm.push_back(input_x[3]); | input_shape_norm.push_back(input_x[3]); | ||||
| @@ -68,12 +68,10 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); | |||||
| auto variance_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto global_step_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name); | |||||
| auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto global_step_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name); | CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name); | ||||
| CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name); | CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name); | ||||
| CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name); | ||||
| @@ -55,7 +55,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri | |||||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, | (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, | ||||
| prim_name); | prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); | auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); | ||||
| auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); | auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); | ||||
| @@ -29,7 +29,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| auto out_shape = x_shape; | auto out_shape = x_shape; | ||||
| int64_t block_shape_prod = 1; | int64_t block_shape_prod = 1; | ||||
| @@ -30,8 +30,8 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v | |||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| // check | // check | ||||
| CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); | ||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| @@ -34,10 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto weight_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); | CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); | ||||
| std::vector<int64_t> infer_shape; | std::vector<int64_t> infer_shape; | ||||
| if (weight_shape.size() < 1) { | if (weight_shape.size() < 1) { | ||||
| @@ -50,7 +50,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // infer shape | // infer shape | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| // infer type | // infer type | ||||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | ||||
| std::vector<TypePtr> output_types; | std::vector<TypePtr> output_types; | ||||
| @@ -24,7 +24,7 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto value_ptr = primitive->GetAttr(kShape); | auto value_ptr = primitive->GetAttr(kShape); | ||||
| auto input_x = GetValue<std::vector<int64_t>>(value_ptr); | auto input_x = GetValue<std::vector<int64_t>>(value_ptr); | ||||
| int64_t outer_dim_offset = input_x.size() - x_shape.size(); | int64_t outer_dim_offset = input_x.size() - x_shape.size(); | ||||
| @@ -31,7 +31,7 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil"); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | ||||
| auto infer_type = input_args[0]->BuildType(); | auto infer_type = input_args[0]->BuildType(); | ||||
| auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name()); | auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name()); | ||||
| @@ -43,8 +43,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | ||||
| auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(element0); | MS_EXCEPTION_IF_NULL(element0); | ||||
| auto element0_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | |||||
| auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape]; | |||||
| auto element0_rank = SizeToLong(element0_shape.size()); | auto element0_rank = SizeToLong(element0_shape.size()); | ||||
| auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, | CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, | ||||
| @@ -56,8 +55,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive | |||||
| int64_t all_shp = element0_shape[axis]; | int64_t all_shp = element0_shape[axis]; | ||||
| for (size_t i = 1; i < elements.size(); ++i) { | for (size_t i = 1; i < elements.size(); ++i) { | ||||
| std::string elementi = "element" + std::to_string(i); | std::string elementi = "element" + std::to_string(i); | ||||
| auto elementi_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); | |||||
| auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | ||||
| prim_name); | prim_name); | ||||
| for (int64_t j = 0; j < element0_rank; ++j) { | for (int64_t j = 0; j < element0_rank; ++j) { | ||||
| @@ -24,8 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape"); | CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape"); | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape"); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| } | } | ||||
| @@ -79,8 +79,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | ||||
| @@ -28,8 +28,7 @@ namespace { | |||||
| abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| } | } | ||||
| @@ -24,11 +24,10 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // infer shape | // infer shape | ||||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| // infer type | // infer type | ||||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | ||||
| return std::make_shared<abstract::AbstractTensor>(x_type, out_shape); | return std::make_shared<abstract::AbstractTensor>(x_type, out_shape); | ||||
| @@ -24,18 +24,14 @@ namespace ops { | |||||
| AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| // auto input = input_args[0]; | |||||
| // Infer type | // Infer type | ||||
| auto output0_type = kInt32; | auto output0_type = kInt32; | ||||
| auto output1_type = kFloat32; | auto output1_type = kFloat32; | ||||
| // Infer shape | // Infer shape | ||||
| std::vector<int64_t> out_shape; | std::vector<int64_t> out_shape; | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto string_num = input_shape[0]; | auto string_num = input_shape[0]; | ||||
| if (string_num == 0) { | if (string_num == 0) { | ||||
| out_shape.push_back(1); | out_shape.push_back(1); | ||||
| @@ -54,7 +54,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri | |||||
| auto input_x = input_args[0]->cast<abstract::AbstractTensorPtr>(); | auto input_x = input_args[0]->cast<abstract::AbstractTensorPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_x); | MS_EXCEPTION_IF_NULL(input_x); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | ||||
| @@ -119,8 +119,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | ||||
| @@ -120,9 +120,9 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c | |||||
| auto boxes = input_args[0]; | auto boxes = input_args[0]; | ||||
| auto scores = input_args[1]; | auto scores = input_args[1]; | ||||
| auto anchors = input_args[2]; | auto anchors = input_args[2]; | ||||
| auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name); | |||||
| auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name); | |||||
| auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name); | |||||
| auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape]; | |||||
| auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape]; | |||||
| auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; | boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; | ||||
| @@ -43,7 +43,7 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name); | ||||
| std::vector<int64_t> out_shape; | std::vector<int64_t> out_shape; | ||||
| out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end()); | out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end()); | ||||
| @@ -31,11 +31,10 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -36,7 +36,7 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // Infer shape | // Infer shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto dim_val = GetValue<int64_t>(input_args[1]->BuildValue()); | auto dim_val = GetValue<int64_t>(input_args[1]->BuildValue()); | ||||
| auto rank = x_shape.size(); | auto rank = x_shape.size(); | ||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name); | ||||
| @@ -29,9 +29,9 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name); | ||||
| CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name); | CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name); | ||||
| @@ -44,9 +44,9 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name); | |||||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name); | CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name); | ||||
| CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name); | CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name); | ||||
| CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name); | ||||
| @@ -24,8 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| in_shape.pop_back(); | in_shape.pop_back(); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -33,7 +33,7 @@ AbstractBasePtr FftRealInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto out_dtype = kFloat32; | auto out_dtype = kFloat32; | ||||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| out_shape.pop_back(); | out_shape.pop_back(); | ||||
| return std::make_shared<abstract::AbstractTensor>(out_dtype, std::make_shared<abstract::Shape>(out_shape)); | return std::make_shared<abstract::AbstractTensor>(out_dtype, std::make_shared<abstract::Shape>(out_shape)); | ||||
| } | } | ||||
| @@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto prod = 1; | auto prod = 1; | ||||
| int64_t size = x_shape.size(); | int64_t size = x_shape.size(); | ||||
| for (int64_t i = 1; i < size; i++) { | for (int64_t i = 1; i < size; i++) { | ||||
| @@ -28,11 +28,10 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| @@ -53,8 +53,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P | |||||
| MS_EXCEPTION_IF_NULL(input_args[1]); | MS_EXCEPTION_IF_NULL(input_args[1]); | ||||
| auto input0 = input_args[0]; | auto input0 = input_args[0]; | ||||
| auto input1 = input_args[1]; | auto input1 = input_args[1]; | ||||
| auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name); | |||||
| auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name); | |||||
| auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape]; | |||||
| auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape]; | |||||
| auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | ||||
| auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias)); | auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias)); | ||||
| if (has_bias) { | if (has_bias) { | ||||
| @@ -78,8 +78,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P | |||||
| new_k = input1_shape[1]; | new_k = input1_shape[1]; | ||||
| } | } | ||||
| if (has_bias) { | if (has_bias) { | ||||
| auto input2_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| if (input2_shape[0] != input1_shape[0]) { | if (input2_shape[0] != input1_shape[0]) { | ||||
| MS_EXCEPTION(ValueError) << "Bias size invalid"; | MS_EXCEPTION(ValueError) << "Bias size invalid"; | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| @@ -33,7 +33,7 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto x_shape_len = (int64_t)x_shape.size(); | auto x_shape_len = (int64_t)x_shape.size(); | ||||
| auto begin_v = input_args[1]->BuildValue(); | auto begin_v = input_args[1]->BuildValue(); | ||||
| auto size_v = input_args[2]->BuildValue(); | auto size_v = input_args[2]->BuildValue(); | ||||
| @@ -29,8 +29,8 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| // check | // check | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| int64_t x_rank = x_shape.size(); | int64_t x_rank = x_shape.size(); | ||||
| CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); | CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); | ||||
| auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue()); | auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue()); | ||||
| @@ -32,9 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto indices_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto input_rank = input_shape.size(); | auto input_rank = input_shape.size(); | ||||
| auto indices_rank = indices_shape.size(); | auto indices_rank = indices_shape.size(); | ||||
| CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual, | CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual, | ||||
| @@ -28,8 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| } | } | ||||
| @@ -47,13 +47,11 @@ bool BatchNormGrad::get_is_training() const { | |||||
| AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[1]); | MS_EXCEPTION_IF_NULL(input_args[1]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[2]); | MS_EXCEPTION_IF_NULL(input_args[2]); | ||||
| MS_EXCEPTION_IF_NULL(input_args[3]); | MS_EXCEPTION_IF_NULL(input_args[3]); | ||||
| auto y_backprop_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("y_backprop_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), op_name); | |||||
| auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape); | CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape); | ||||
| auto dx = input_args[1]->Broaden(); | auto dx = input_args[1]->Broaden(); | ||||
| @@ -46,7 +46,7 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| // Infer shape | // Infer shape | ||||
| auto inshape = CheckAndConvertUtils::ConvertShapePtrToShape("inshape", input_args[0]->BuildShape(), prim_name); | |||||
| auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| for (size_t i = 0; i < inshape.size() - 1; i++) { | for (size_t i = 0; i < inshape.size() - 1; i++) { | ||||
| inshape[i] = 1; | inshape[i] = 1; | ||||
| } | } | ||||
| @@ -27,10 +27,9 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto weight_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); | CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); | ||||
| if (weight_shape.size() < 1) { | if (weight_shape.size() < 1) { | ||||
| CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); | CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); | ||||
| @@ -35,8 +35,7 @@ namespace { | |||||
| abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -21,10 +21,8 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | ||||
| auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tensor_type); | MS_EXCEPTION_IF_NULL(tensor_type); | ||||
| auto element = tensor_type->element(); | auto element = tensor_type->element(); | ||||
| @@ -35,9 +35,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE | |||||
| prim_name); | prim_name); | ||||
| // Infer Shape | // Infer Shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dout_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); | CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); | ||||
| CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError); | CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError); | ||||
| @@ -40,9 +40,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const | |||||
| CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name); | |||||
| auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[1]->BuildShape(), prim_name); | |||||
| auto dloss = CheckAndConvertUtils::ConvertShapePtrToShape("dloss", input_args[2]->BuildShape(), prim_name); | |||||
| auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); | CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); | ||||
| CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError); | CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError); | ||||
| @@ -27,9 +27,8 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const | |||||
| for (auto input : input_args) { | for (auto input : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| } | } | ||||
| auto op_name = primitive->name(); | |||||
| std::vector<int64_t> hits_shape; | std::vector<int64_t> hits_shape; | ||||
| auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| hits_shape.push_back(input[0]); | hits_shape.push_back(input[0]); | ||||
| auto value_type = input_args[2]->BuildType(); | auto value_type = input_args[2]->BuildType(); | ||||
| @@ -46,7 +46,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| } | } | ||||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | ||||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); | (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto x_rank = SizeToLong(x_shape.size()); | auto x_rank = SizeToLong(x_shape.size()); | ||||
| auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); | auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); | ||||
| for (auto &axis : axiss) { | for (auto &axis : axiss) { | ||||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Log"); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(x_shape); | return std::make_shared<abstract::Shape>(x_shape); | ||||
| } | } | ||||
| @@ -24,8 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -78,7 +78,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -32,14 +32,14 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); | |||||
| auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); | CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); | ||||
| CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name); | CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name); | ||||
| CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name); | ||||
| if (input_args.size() == 3) { | if (input_args.size() == 3) { | ||||
| auto input2 = CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), op_name); | |||||
| auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name); | ||||
| CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name); | CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name); | ||||
| } | } | ||||
| @@ -32,9 +32,9 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); | ||||
| auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name); | |||||
| auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; | |||||
| int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size)); | int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size)); | ||||
| CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); | ||||
| @@ -26,8 +26,8 @@ abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA)); | auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA)); | ||||
| auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB)); | auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB)); | ||||
| @@ -30,9 +30,8 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto assist_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name); | ||||
| CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank", | CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank", | ||||
| @@ -82,7 +82,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | ||||
| @@ -25,10 +25,8 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto first_input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto second_input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); | CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); | ||||
| std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], | std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], | ||||
| @@ -31,7 +31,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); | CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); | ||||
| auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue()); | auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue()); | ||||
| CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name); | CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name); | ||||
| @@ -28,9 +28,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(input_shape); | return std::make_shared<abstract::Shape>(input_shape); | ||||
| } | } | ||||
| @@ -27,8 +27,8 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_LOG(INFO) << "Do infer shape for op " << op_name; | MS_LOG(INFO) << "Do infer shape for op " << op_name; | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]; | |||||
| if (x_shape == y_shape) { | if (x_shape == y_shape) { | ||||
| return std::make_shared<abstract::Shape>(x_shape); | return std::make_shared<abstract::Shape>(x_shape); | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve | |||||
| std::string name) { | std::string name) { | ||||
| CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name); | CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name); | ||||
| CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name); | CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name); | ||||
| auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape[0]", x_shapes[0], name); | |||||
| auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[0])[kShape]; | |||||
| int64_t rank_base = output_shape.size(); | int64_t rank_base = output_shape.size(); | ||||
| int64_t N = x_shapes.size(); | int64_t N = x_shapes.size(); | ||||
| // CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name); | // CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name); | ||||
| @@ -37,7 +37,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve | |||||
| MS_EXCEPTION_IF_NULL(type0); | MS_EXCEPTION_IF_NULL(type0); | ||||
| CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(), | CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(), | ||||
| name); | name); | ||||
| auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape" + std::to_string(i), x_shapes[i], name); | |||||
| auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[i])[kShape]; | |||||
| if (shape != output_shape) { | if (shape != output_shape) { | ||||
| MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) + | MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) + | ||||
| "shape in input can't pack with first element."; | "shape in input can't pack with first element."; | ||||
| @@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); | auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad"); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), | CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), | ||||
| prim_name); | prim_name); | ||||
| int64_t size = paddings_attr.size(); | int64_t size = paddings_attr.size(); | ||||
| @@ -25,8 +25,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto x = input_args[0]->BuildShape(); | auto x = input_args[0]->BuildShape(); | ||||
| auto w = input_args[1]->BuildShape(); | auto w = input_args[1]->BuildShape(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", x, prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", w, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape]; | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name); | ||||
| @@ -112,7 +112,6 @@ void PriorBox::Init(const std::vector<int64_t> &min_sizes, const std::vector<int | |||||
| AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| std::vector<float> different_aspect_ratios{1.0f}; | std::vector<float> different_aspect_ratios{1.0f}; | ||||
| auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios)); | auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios)); | ||||
| @@ -129,7 +128,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti | |||||
| } | } | ||||
| auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes)); | auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes)); | ||||
| int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size(); | int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size(); | ||||
| auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| int64_t h = input[0] * input[1] * num_priors_box * 4; | int64_t h = input[0] * input[1] * num_priors_box * 4; | ||||
| std::vector<int64_t> output_shape{1, h, 1, 2}; | std::vector<int64_t> output_shape{1, h, 1, 2}; | ||||
| return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape); | return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape); | ||||
| @@ -32,13 +32,12 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { | |||||
| AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | MS_EXCEPTION_IF_NULL(input_args[0]); | ||||
| auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_type); | MS_EXCEPTION_IF_NULL(input_type); | ||||
| auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT)); | auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT)); | ||||
| MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type))); | MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type))); | ||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape); | return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); | REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); | ||||
| @@ -34,8 +34,7 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // infer shape | // infer shape | ||||
| auto in_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| // infer type | // infer type | ||||
| std::set<TypePtr> valid_x_type = {kTensorType}; | std::set<TypePtr> valid_x_type = {kTensorType}; | ||||
| auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); | auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); | ||||
| @@ -71,8 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto input_x_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims)); | auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims)); | ||||
| auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); | auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); | ||||
| @@ -49,8 +49,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P | |||||
| CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); | ||||
| std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]}; | std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]}; | ||||
| auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize)); | auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize)); | ||||
| @@ -44,10 +44,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // infer shape | // infer shape | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto seq_lengths = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto seq_lengths = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim)); | auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim)); | ||||
| auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim)); | auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim)); | ||||
| CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); | CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); | ||||
| @@ -24,8 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(x_shape); | return std::make_shared<abstract::Shape>(x_shape); | ||||
| } | } | ||||
| @@ -24,9 +24,7 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto first_input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto out_shape = first_input_shape; | auto out_shape = first_input_shape; | ||||
| out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1; | out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1; | ||||
| out_shape.push_back(2); | out_shape.push_back(2); | ||||
| @@ -62,9 +62,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi | |||||
| // Infer shape | // Infer shape | ||||
| auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH)); | auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH)); | ||||
| auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW)); | auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW)); | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| std::vector<int64_t> output_shape; | std::vector<int64_t> output_shape; | ||||
| output_shape.push_back(roi_shape[0]); | output_shape.push_back(roi_shape[0]); | ||||
| output_shape.push_back(new_h); | output_shape.push_back(new_h); | ||||
| @@ -23,7 +23,7 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "round"); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(x_shape); | return std::make_shared<abstract::Shape>(x_shape); | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); | ||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -29,7 +29,7 @@ abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive, | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| // check | // check | ||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | ||||
| return std::make_shared<abstract::Shape>(ShapeVector(1)); | return std::make_shared<abstract::Shape>(ShapeVector(1)); | ||||
| } | } | ||||
| @@ -29,10 +29,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| for (const auto &shape : shape_value_element) { | for (const auto &shape : shape_value_element) { | ||||
| CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd"); | CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd"); | ||||
| } | } | ||||
| auto indices_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[0]->BuildShape(), "ScatterNd"); | |||||
| auto update_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("update_shape", input_args[1]->BuildShape(), "ScatterNd"); | |||||
| auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0], | CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0], | ||||
| "ScatterNd"); | "ScatterNd"); | ||||
| return std::make_shared<abstract::Shape>(shape_value_element); | return std::make_shared<abstract::Shape>(shape_value_element); | ||||
| @@ -34,8 +34,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin | |||||
| prim_name); | prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); | CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); | ||||
| // Infer type | // Infer type | ||||
| @@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Sin"); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(x_shape); | return std::make_shared<abstract::Shape>(x_shape); | ||||
| } | } | ||||
| @@ -23,7 +23,6 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| if (input_args.size() != 1) { | if (input_args.size() != 1) { | ||||
| MS_LOG(ERROR) << "Skip Gram should have one input"; | MS_LOG(ERROR) << "Skip Gram should have one input"; | ||||
| } | } | ||||
| @@ -31,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| if (infer_value == nullptr) { | if (infer_value == nullptr) { | ||||
| MS_LOG(INFO) << "Do infer shape in runtime."; | MS_LOG(INFO) << "Do infer shape in runtime."; | ||||
| } | } | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape); | return std::make_shared<abstract::Shape>(in_shape); | ||||
| } | } | ||||
| @@ -40,8 +40,8 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri | |||||
| CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name); | |||||
| auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[0]->BuildShape(), prim_name); | |||||
| auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); | CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); | ||||
| // Infer type | // Infer type | ||||
| @@ -34,10 +34,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin | |||||
| prim_name); | prim_name); | ||||
| // Infer shape | // Infer shape | ||||
| auto logits_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("logits_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto labels_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("labels_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto labels_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError); | CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError); | ||||
| std::vector<int64_t> loss_shape = {logits_shape[0]}; | std::vector<int64_t> loss_shape = {logits_shape[0]}; | ||||
| auto dlogits_shape = logits_shape; | auto dlogits_shape = logits_shape; | ||||
| @@ -29,8 +29,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); | ||||
| std::vector<int64_t> output_shape(input_shape.size()); | std::vector<int64_t> output_shape(input_shape.size()); | ||||
| auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); | auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); | ||||
| @@ -29,7 +29,7 @@ namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| auto out_shape = x_shape; | auto out_shape = x_shape; | ||||
| int64_t block_shape_prod = 1; | int64_t block_shape_prod = 1; | ||||
| @@ -43,8 +43,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // infer shape | // infer shape | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| std::vector<int64_t> output_shape; | std::vector<int64_t> output_shape; | ||||
| if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) { | if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) { | ||||
| output_shape = input_shape; | output_shape = input_shape; | ||||
| @@ -33,8 +33,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| // infer shape | // infer shape | ||||
| auto dense_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(), prim_name); | |||||
| auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; | |||||
| // infer type | // infer type | ||||
| auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | ||||
| return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape); | return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape); | ||||
| @@ -29,7 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); | auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); | ||||
| std::vector<int64_t> infer_shape; | std::vector<int64_t> infer_shape; | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; | |||||
| auto len = SizeToLong(in_shape.size()); | auto len = SizeToLong(in_shape.size()); | ||||
| if (axis.empty()) { | if (axis.empty()) { | ||||
| std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape), | std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape), | ||||
| @@ -21,7 +21,6 @@ namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| if (input_args.size() != 1) { | if (input_args.size() != 1) { | ||||
| MS_LOG(ERROR) << "Invalid output size:" << input_args.size(); | MS_LOG(ERROR) << "Invalid output size:" << input_args.size(); | ||||
| @@ -29,11 +28,9 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v | |||||
| if (input_args.size() < 1) { | if (input_args.size() < 1) { | ||||
| MS_LOG(ERROR) << "Invalid input size " << input_args.size(); | MS_LOG(ERROR) << "Invalid input size " << input_args.size(); | ||||
| } | } | ||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) { | for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) { | ||||
| auto input_shape_tmp = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[i]->BuildShape(), prim_name); | |||||
| auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; | |||||
| if (input_shape_tmp.size() != input_shape.size()) { | if (input_shape_tmp.size() != input_shape.size()) { | ||||
| MS_LOG(ERROR) << "All input shape size should be the same!"; | MS_LOG(ERROR) << "All input shape size should be the same!"; | ||||
| } | } | ||||
| @@ -108,7 +108,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, | |||||
| auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue(); | auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue(); | ||||
| auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v); | auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v); | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| int64_t x_rank = x_shape.size(); | int64_t x_rank = x_shape.size(); | ||||
| int64_t slice_len = begin_v.size(); | int64_t slice_len = begin_v.size(); | ||||
| std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask))); | std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask))); | ||||
| @@ -33,7 +33,7 @@ AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||||
| CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| // Infer Shape | // Infer Shape | ||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | ||||
| // Infer Type | // Infer Type | ||||
| @@ -24,11 +24,8 @@ namespace { | |||||
| abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| auto input0_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input0 shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input1_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input1 shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| if (input0_shape.size() < 1) { | if (input0_shape.size() < 1) { | ||||
| MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!"; | MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!"; | ||||
| } | } | ||||
| @@ -52,9 +52,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const | |||||
| for (const auto &input : input_args) { | for (const auto &input : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| } | } | ||||
| auto op_name = primitive->name(); | |||||
| auto input0_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||||
| int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies<int64_t>()); | int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies<int64_t>()); | ||||
| if (num == 0) { | if (num == 0) { | ||||
| MS_LOG(ERROR) << "Try to stack a empty tensorlist!"; | MS_LOG(ERROR) << "Try to stack a empty tensorlist!"; | ||||
| @@ -62,8 +60,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const | |||||
| if (input_args[1]->BuildShape() == nullptr) { | if (input_args[1]->BuildShape() == nullptr) { | ||||
| MS_LOG(ERROR) << "ele_shape->data_c() is nullptr"; | MS_LOG(ERROR) << "ele_shape->data_c() is nullptr"; | ||||
| } | } | ||||
| auto input1_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); | |||||
| auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| input1_shape.insert(input1_shape.begin(), 1); | input1_shape.insert(input1_shape.begin(), 1); | ||||
| return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape); | return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape); | ||||
| } | } | ||||
| @@ -29,7 +29,7 @@ abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive, | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| // check | // check | ||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | ||||
| return std::make_shared<abstract::Shape>(ShapeVector(1)); | return std::make_shared<abstract::Shape>(ShapeVector(1)); | ||||
| } | } | ||||