diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 6eba48d1c7..1c94397526 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -363,7 +363,7 @@ set_target_properties(_c_expression PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) if(CMAKE_SYSTEM_NAME MATCHES "Windows") target_link_libraries(mindspore mindspore::pybind11_module) target_link_libraries(mindspore mindspore_gvar) - target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) + target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core -Wl,--no-whole-archive) elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") target_link_libraries(mindspore mindspore::pybind11_module) target_link_libraries(mindspore mindspore_gvar) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc index 3ad53404b9..db2f833708 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.cc @@ -459,7 +459,7 @@ AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynam std::vector shape = {t_size, IntToSize(1), n_size}; std::vector output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)}; std::vector output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)}; - auto tensor = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, output_tensor); + auto tensor = TensorConstructUtils::CreateOnesTensor(kFloat32, output_tensor); auto x_abstract = std::make_shared(kFloat32, output_shape); auto kernel_graph = func_graph->cast(); auto value_node = kernel_graph->NewValueNode(x_abstract, tensor); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 181195b06b..dd6b225f1f 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -287,6 +287,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared("Elu"); inline const PrimitivePtr kPrimRelu6 = std::make_shared("ReLU6"); inline const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); inline const PrimitivePtr kPrimPRelu = std::make_shared("PReLU"); +inline const PrimitivePtr kPrimZeros = std::make_shared("Zeros"); inline const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); inline const PrimitivePtr kPrimOnesLike = std::make_shared("OnesLike"); inline const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); diff --git a/mindspore/core/ops/bias_add.cc b/mindspore/core/ops/bias_add.cc index b6660b28e0..d1bb713b27 100644 --- a/mindspore/core/ops/bias_add.cc +++ b/mindspore/core/ops/bias_add.cc @@ -29,7 +29,7 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); // check - CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); + CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); @@ -55,7 +55,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector types; types.emplace("input_x", input_args[0]->BuildType()); types.emplace("bias", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name); } } // namespace void BiasAdd::set_format(const Format &format) { diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc index 31ec4947b5..47a274ebc6 100644 --- a/mindspore/core/ops/gather.cc +++ b/mindspore/core/ops/gather.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,28 +20,6 @@ namespace mindspore { namespace ops { -AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name); - - // Infer type - std::set valid_x_type = {kTensorType}; - auto x_type = - CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); - std::set valid_index_types = {kInt32, kInt64}; - CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name); - std::set valid_dim_type = {kInt32, kInt64}; - CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name); - - // Infer shape - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); - CheckAndConvertUtils::Check("x_rank", x_shape.size(), kEqual, "index_rank", index_shape.size(), prim_name); - - return std::make_shared(x_type, index_shape); -} REGISTER_PRIMITIVE_C(kNameGather, Gather); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/gather.h b/mindspore/core/ops/gather.h index aa1e88a4f9..55c735f77c 100644 --- a/mindspore/core/ops/gather.h +++ b/mindspore/core/ops/gather.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,8 +34,6 @@ class Gather : public PrimitiveC { MS_DECLARE_PARENT(Gather, PrimitiveC); void Init() {} }; -AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); using PrimGatherPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc new file mode 100644 index 0000000000..b195977a9c --- /dev/null +++ b/mindspore/core/ops/gather_d.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/gather_d.h" +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +// gather_d +namespace { +abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + // check + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); + auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); + int64_t x_rank = x_shape.size(); + CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); + auto dim_v = GetValue(input_args[1]->BuildValue()); + CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name); + CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name); + + if (dim_v < 0) { + dim_v = dim_v + x_rank; + } + for (int i = 0; i < x_rank; ++i) { + if (i == dim_v) continue; + MS_LOG(INFO) << "Check " << i << "th x shape"; + CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, "index_rank", index_shape[i], prim_name); + } + return std::make_shared(index_shape); +} + +TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + // check + std::set valid_x_type = {kTensorType}; + auto x_type = + CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); + std::set valid_index_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name); + std::set valid_dim_type = {kInt32, kInt64}; + CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name); + return x_type; +} +} // namespace +AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto abs = std::make_shared(GatherDInferType(primitive, input_args), + GatherDInferShape(primitive, input_args)); + return abs; +} +REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false); +REGISTER_PRIMITIVE_C(kNameGatherD, GatherD); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/gather_d.h b/mindspore/core/ops/gather_d.h new file mode 100644 index 0000000000..d39958b118 --- /dev/null +++ b/mindspore/core/ops/gather_d.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_GATHER_D_H_ +#define MINDSPORE_CORE_OPS_GATHER_D_H_ +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" +#include "ops/op_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameGatherD = "GatherD"; +class GatherD : public PrimitiveC { + public: + GatherD() : PrimitiveC(kNameGatherD) { InitIOName({"x", "dim", "index"}, {"output"}); } + ~GatherD() = default; + MS_DECLARE_PARENT(GatherD, PrimitiveC); + void Init() {} +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_GATHER_D_H_ diff --git a/mindspore/core/ops/scalar_summary.cc b/mindspore/core/ops/scalar_summary.cc index 342c4d4a43..03e0c87adc 100644 --- a/mindspore/core/ops/scalar_summary.cc +++ b/mindspore/core/ops/scalar_summary.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,18 @@ namespace mindspore { namespace ops { - +// scalar_summary +namespace { +abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + // check + auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); + return std::make_shared(ShapeVector(1)); +} +} // namespace void ScalarSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); } bool ScalarSummary::get_side_effect_io() const { @@ -35,12 +46,9 @@ void ScalarSummary::Init() { this->set_side_effect_io(); } AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); // check - CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name); - auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); - CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); - return std::make_shared(kInt32, std::make_shared(ShapeVector(1))); + CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name()); + return std::make_shared(kInt32, ScalarSummaryInferShape(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); diff --git a/mindspore/core/ops/scalar_summary.h b/mindspore/core/ops/scalar_summary.h index 332adef2fc..e89c85aca1 100644 --- a/mindspore/core/ops/scalar_summary.h +++ b/mindspore/core/ops/scalar_summary.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/core/ops/tensor_summary.cc b/mindspore/core/ops/tensor_summary.cc index 4b17b528e8..efe23d4e43 100644 --- a/mindspore/core/ops/tensor_summary.cc +++ b/mindspore/core/ops/tensor_summary.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,18 @@ namespace mindspore { namespace ops { - +// scalar_summary +namespace { +abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + // check + auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); + CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); + return std::make_shared(ShapeVector(1)); +} +} // namespace void TensorSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); } bool TensorSummary::get_side_effect_io() const { @@ -35,12 +46,9 @@ void TensorSummary::Init() { this->set_side_effect_io(); } AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); // check - CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name); - auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); - CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); - return std::make_shared(kInt32, std::make_shared(ShapeVector(1))); + CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name()); + return std::make_shared(kInt32, TensorSummaryInferShape(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); diff --git a/mindspore/core/ops/tensor_summary.h b/mindspore/core/ops/tensor_summary.h index ec8d5b676a..61badeadb2 100644 --- a/mindspore/core/ops/tensor_summary.h +++ b/mindspore/core/ops/tensor_summary.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc new file mode 100644 index 0000000000..e39a0a58a2 --- /dev/null +++ b/mindspore/core/ops/zeros.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/zeros.h" +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +// zeros +namespace { +abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + // check + auto shape_value = input_args[0]->BuildValue(); + std::vector out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name); + CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name); + return std::make_shared(out_shape); +} + +TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + // check + auto dtype_value = input_args[1]->BuildValue(); + if (!dtype_value->isa()) { + MS_EXCEPTION(TypeError) << "The dtype of Zeros is invalid!"; + } + auto output_type = dtype_value->cast(); + const std::set valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, + kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; + return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); +} +ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector &input_args, + const abstract::AbstractBasePtr &abs) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + // check + auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("output shape", abs->BuildShape(), prim_name); + auto out_type = abs->BuildType(); + MS_EXCEPTION_IF_NULL(out_type); + return TensorConstructUtils::CreateZerosTensor(out_type, out_shape); +} +} // namespace + +AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto abs = std::make_shared(ZerosInferType(primitive, input_args), + ZerosInferShape(primitive, input_args)); + abs->set_value(ZerosInferValue(primitive, input_args, abs)); + return abs; +} +REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); +REGISTER_PRIMITIVE_C(kNameZeros, Zeros); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/zeros.h b/mindspore/core/ops/zeros.h new file mode 100644 index 0000000000..b9afc5d8e2 --- /dev/null +++ b/mindspore/core/ops/zeros.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_ZEROS_H_ +#define MINDSPORE_CORE_OPS_ZEROS_H_ +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" +#include "ops/op_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameZeros = "Zeros"; +class Zeros : public PrimitiveC { + public: + Zeros() : PrimitiveC(kNameZeros) {} + ~Zeros() = default; + MS_DECLARE_PARENT(Zeros, PrimitiveC); + void Init() {} +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_ZEROS_H_ diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index 0aa3d3d706..8e2594165e 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -442,8 +442,8 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::mapsecond; MS_EXCEPTION_IF_NULL(type); if (!type->isa()) { - MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << types.begin()->first << " input must be a tensor but got " - << type->ToString(); + MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << types.begin()->first + << " input must be a tensor but got " << type->ToString(); } TypePtr check_type = _CheckTypeSame(types, prim_name, false); return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name); @@ -599,4 +599,27 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) { MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode."; } } + +std::vector CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr, + const std::string &prim_name) { + std::vector result; + MS_EXCEPTION_IF_NULL(attr); + if (attr->isa()) { + std::vector attr_vec = attr->cast()->value(); + (void)std::transform( + attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t { + if (!e->isa()) { + MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Int64"; + } + return GetValue(e); + }); + } else { + if (!attr->isa()) { + MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Int64"; + } + int64_t attr_val = attr->cast()->value(); + result.push_back(attr_val); + } + return result; +} } // namespace mindspore diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index 03bf225164..c3e76d13ee 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -321,6 +321,8 @@ class CheckAndConvertUtils { static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value, const std::string &class_name); static void CheckMode(const std::string &class_name); + static std::vector CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr, + const std::string &arg_name); private: static bool IsEqualVector(const std::vector &vec_1, const std::vector &vec_2); diff --git a/mindspore/core/utils/tensor_construct_utils.cc b/mindspore/core/utils/tensor_construct_utils.cc index 1563ecc074..a34b205b28 100644 --- a/mindspore/core/utils/tensor_construct_utils.cc +++ b/mindspore/core/utils/tensor_construct_utils.cc @@ -17,8 +17,10 @@ #include #include namespace mindspore { -tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector &shape) { - tensor::TensorPtr tensor = std::make_shared(type, shape); +tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector &shape) { + MS_EXCEPTION_IF_NULL(type_ptr); + auto type_id = ExtractTypeId(type_ptr); + tensor::TensorPtr tensor = std::make_shared(type_id, shape); size_t mem_size = IntToSize(tensor->ElementsNum()); auto tensor_data = tensor->data_c(); char *data = reinterpret_cast(tensor_data); @@ -28,8 +30,10 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std return tensor; } -tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector &shape) { - tensor::TensorPtr tensor = std::make_shared(type, shape); +tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector &shape) { + MS_EXCEPTION_IF_NULL(type_ptr); + auto type_id = ExtractTypeId(type_ptr); + tensor::TensorPtr tensor = std::make_shared(type_id, shape); size_t mem_size = IntToSize(tensor->ElementsNum()); if (tensor->data_type() == kNumberTypeFloat32) { SetTensorData(tensor->data_c(), 1.0, mem_size); @@ -39,8 +43,18 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std: return tensor; } -tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector &shape, void *data) { - tensor::TensorPtr tensor = std::make_shared(type, shape, data, type); +tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector &shape, + void *data) { + MS_EXCEPTION_IF_NULL(type_ptr); + auto type_id = ExtractTypeId(type_ptr); + tensor::TensorPtr tensor = std::make_shared(type_id, shape, data, type_id); return tensor; } + +TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) { + MS_EXCEPTION_IF_NULL(type_ptr); + auto tensor_type = type_ptr->cast(); + auto type_id = tensor_type->element()->type_id(); + return type_id; +} } // namespace mindspore diff --git a/mindspore/core/utils/tensor_construct_utils.h b/mindspore/core/utils/tensor_construct_utils.h index 4bb87ab27b..72dac8f4cc 100644 --- a/mindspore/core/utils/tensor_construct_utils.h +++ b/mindspore/core/utils/tensor_construct_utils.h @@ -30,9 +30,10 @@ void SetTensorData(void *data, T num, size_t data_length) { } class TensorConstructUtils { public: - static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector &shape); - static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector &shape); - static tensor::TensorPtr CreateTensor(TypeId type, const std::vector &shape, void *data); + static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector &shape); + static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector &shape); + static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector &shape, void *data); + static TypeId ExtractTypeId(const TypePtr type); }; } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 0ad7cf83f6..f379ac50ee 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1342,27 +1342,6 @@ class Zeros(PrimitiveWithInfer): def __init__(self): """Initialize Zeros""" - def __infer__(self, dims, dtype): - if isinstance(dims['value'], int): - shape = (dims['value'],) - else: - shape = dims['value'] - validator.check_value_type("shape", shape, [tuple], self.name) - for i, item in enumerate(shape): - validator.check_non_negative_int(item, shape[i], self.name) - valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, - mstype.float16, mstype.float32, mstype.float64] - validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) - x_nptype = mstype.dtype_to_nptype(dtype['value']) - ret = np.zeros(shape, x_nptype) - out = { - 'value': Tensor(ret), - 'shape': shape, - 'dtype': x_nptype, - } - return out - class OnesLike(PrimitiveWithInfer): """ @@ -5193,30 +5172,6 @@ class GatherD(PrimitiveWithInfer): """Initialize GatherD""" self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output']) - def __infer__(self, x, dim, index): - validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) - validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name) - validator.check_subclass("dim", dim['dtype'], [mstype.int32, mstype.int64], self.name) - x_shp = x['shape'] - idx_shp = index['shape'] - x_rank = len(x_shp) - idx_rank = len(idx_shp) - validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name) - dim_v = dim['value'] - validator.check("dim value", dim_v, "expected", -x_rank, Rel.GE, self.name) - validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name) - if dim_v < 0: - dim['value'] = dim_v + x_rank - for i in range(x_rank): - if i == dim['value']: - continue - validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name) - - out = {'shape': index['shape'], - 'dtype': x['dtype'], - 'value': None} - return out - class Identity(PrimitiveWithInfer): """ diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 1c2bddb37a..06b0bb2f60 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -89,17 +89,6 @@ class ScalarSummary(PrimitiveWithInfer): """init""" self.add_prim_attr("side_effect_io", True) - def __infer__(self, name, value): - _check_summary_param(name, value, self.__class__.__name__) - - v_shape = value['shape'] - # In the summary, the value whose shape is [1] is also considered as a scalar. - if v_shape and v_shape != [1]: - raise ValueError(f"For 'value' the type should be scalar, " - f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.") - - return SUMMARY_RETURN_VALUE - class ImageSummary(PrimitiveWithInfer): """ @@ -191,17 +180,6 @@ class TensorSummary(PrimitiveWithInfer): """init""" self.add_prim_attr("side_effect_io", True) - def __infer__(self, name, value): - _check_summary_param(name, value, self.__class__.__name__) - - v_shape = value['shape'] - # In the summary, the value whose shape is [] is not considered as a tensor. - if not v_shape: - raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " - f"shape should not be [].") - - return SUMMARY_RETURN_VALUE - class HistogramSummary(PrimitiveWithInfer): """