From: @simson_wu Reviewed-by: @chujinjin,@zh_qh Signed-off-by: @zh_qhpull/14530/MERGE
| @@ -363,7 +363,7 @@ set_target_properties(_c_expression PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) | |||||
| if(CMAKE_SYSTEM_NAME MATCHES "Windows") | if(CMAKE_SYSTEM_NAME MATCHES "Windows") | ||||
| target_link_libraries(mindspore mindspore::pybind11_module) | target_link_libraries(mindspore mindspore::pybind11_module) | ||||
| target_link_libraries(mindspore mindspore_gvar) | target_link_libraries(mindspore mindspore_gvar) | ||||
| target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) | |||||
| target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core -Wl,--no-whole-archive) | |||||
| elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | ||||
| target_link_libraries(mindspore mindspore::pybind11_module) | target_link_libraries(mindspore mindspore::pybind11_module) | ||||
| target_link_libraries(mindspore mindspore_gvar) | target_link_libraries(mindspore mindspore_gvar) | ||||
| @@ -459,7 +459,7 @@ AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynam | |||||
| std::vector<size_t> shape = {t_size, IntToSize(1), n_size}; | std::vector<size_t> shape = {t_size, IntToSize(1), n_size}; | ||||
| std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)}; | std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)}; | ||||
| std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)}; | std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)}; | ||||
| auto tensor = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, output_tensor); | |||||
| auto tensor = TensorConstructUtils::CreateOnesTensor(kFloat32, output_tensor); | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape); | ||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | ||||
| auto value_node = kernel_graph->NewValueNode(x_abstract, tensor); | auto value_node = kernel_graph->NewValueNode(x_abstract, tensor); | ||||
| @@ -287,6 +287,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu"); | |||||
| inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | ||||
| inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | ||||
| inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU"); | inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU"); | ||||
| inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros"); | |||||
| inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | ||||
| inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | ||||
| inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | ||||
| @@ -29,7 +29,7 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | auto prim_name = primitive->name(); | ||||
| // check | // check | ||||
| CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); | auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); | ||||
| @@ -55,7 +55,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBas | |||||
| std::map<std::string, TypePtr> types; | std::map<std::string, TypePtr> types; | ||||
| types.emplace("input_x", input_args[0]->BuildType()); | types.emplace("input_x", input_args[0]->BuildType()); | ||||
| types.emplace("bias", input_args[1]->BuildType()); | types.emplace("bias", input_args[1]->BuildType()); | ||||
| return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||||
| return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void BiasAdd::set_format(const Format &format) { | void BiasAdd::set_format(const Format &format) { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -20,28 +20,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name); | |||||
| // Infer type | |||||
| std::set<TypePtr> valid_x_type = {kTensorType}; | |||||
| auto x_type = | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); | |||||
| std::set<TypePtr> valid_index_types = {kInt32, kInt64}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name); | |||||
| std::set<TypePtr> valid_dim_type = {kInt32, kInt64}; | |||||
| CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name); | |||||
| // Infer shape | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::Check("x_rank", x_shape.size(), kEqual, "index_rank", index_shape.size(), prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(x_type, index_shape); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameGather, Gather); | REGISTER_PRIMITIVE_C(kNameGather, Gather); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -34,8 +34,6 @@ class Gather : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(Gather, PrimitiveC); | MS_DECLARE_PARENT(Gather, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimGatherPtr = std::shared_ptr<Gather>; | using PrimGatherPtr = std::shared_ptr<Gather>; | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/gather_d.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // gather_d | |||||
| namespace { | |||||
| abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| // check | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); | |||||
| int64_t x_rank = x_shape.size(); | |||||
| CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); | |||||
| auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue()); | |||||
| CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name); | |||||
| CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name); | |||||
| if (dim_v < 0) { | |||||
| dim_v = dim_v + x_rank; | |||||
| } | |||||
| for (int i = 0; i < x_rank; ++i) { | |||||
| if (i == dim_v) continue; | |||||
| MS_LOG(INFO) << "Check " << i << "th x shape"; | |||||
| CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, "index_rank", index_shape[i], prim_name); | |||||
| } | |||||
| return std::make_shared<abstract::Shape>(index_shape); | |||||
| } | |||||
| TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| // check | |||||
| std::set<TypePtr> valid_x_type = {kTensorType}; | |||||
| auto x_type = | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); | |||||
| std::set<TypePtr> valid_index_types = {kInt32, kInt64}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name); | |||||
| std::set<TypePtr> valid_dim_type = {kInt32, kInt64}; | |||||
| CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name); | |||||
| return x_type; | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args), | |||||
| GatherDInferShape(primitive, input_args)); | |||||
| return abs; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false); | |||||
| REGISTER_PRIMITIVE_C(kNameGatherD, GatherD); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_GATHER_D_H_ | |||||
| #define MINDSPORE_CORE_OPS_GATHER_D_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameGatherD = "GatherD"; | |||||
| class GatherD : public PrimitiveC { | |||||
| public: | |||||
| GatherD() : PrimitiveC(kNameGatherD) { InitIOName({"x", "dim", "index"}, {"output"}); } | |||||
| ~GatherD() = default; | |||||
| MS_DECLARE_PARENT(GatherD, PrimitiveC); | |||||
| void Init() {} | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_GATHER_D_H_ | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -22,7 +22,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| // scalar_summary | |||||
| namespace { | |||||
| abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| // check | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | |||||
| return std::make_shared<abstract::Shape>(ShapeVector(1)); | |||||
| } | |||||
| } // namespace | |||||
| void ScalarSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); } | void ScalarSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); } | ||||
| bool ScalarSummary::get_side_effect_io() const { | bool ScalarSummary::get_side_effect_io() const { | ||||
| @@ -35,12 +46,9 @@ void ScalarSummary::Init() { this->set_side_effect_io(); } | |||||
| AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| // check | // check | ||||
| CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | |||||
| CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name()); | |||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args)); | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); | REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); | ||||
| REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); | REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -22,7 +22,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| // scalar_summary | |||||
| namespace { | |||||
| abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| // check | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | |||||
| return std::make_shared<abstract::Shape>(ShapeVector(1)); | |||||
| } | |||||
| } // namespace | |||||
| void TensorSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); } | void TensorSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); } | ||||
| bool TensorSummary::get_side_effect_io() const { | bool TensorSummary::get_side_effect_io() const { | ||||
| @@ -35,12 +46,9 @@ void TensorSummary::Init() { this->set_side_effect_io(); } | |||||
| AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto prim_name = primitive->name(); | |||||
| // check | // check | ||||
| CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | |||||
| CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name()); | |||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args)); | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); | REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); | ||||
| REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); | REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/zeros.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "utils/tensor_construct_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // zeros | |||||
| namespace { | |||||
| abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| // check | |||||
| auto shape_value = input_args[0]->BuildValue(); | |||||
| std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name); | |||||
| CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name); | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| // check | |||||
| auto dtype_value = input_args[1]->BuildValue(); | |||||
| if (!dtype_value->isa<Type>()) { | |||||
| MS_EXCEPTION(TypeError) << "The dtype of Zeros is invalid!"; | |||||
| } | |||||
| auto output_type = dtype_value->cast<TypePtr>(); | |||||
| const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, | |||||
| kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | |||||
| return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); | |||||
| } | |||||
| ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args, | |||||
| const abstract::AbstractBasePtr &abs) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| // check | |||||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("output shape", abs->BuildShape(), prim_name); | |||||
| auto out_type = abs->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(out_type); | |||||
| return TensorConstructUtils::CreateZerosTensor(out_type, out_shape); | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args), | |||||
| ZerosInferShape(primitive, input_args)); | |||||
| abs->set_value(ZerosInferValue(primitive, input_args, abs)); | |||||
| return abs; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); | |||||
| REGISTER_PRIMITIVE_C(kNameZeros, Zeros); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_ZEROS_H_ | |||||
| #define MINDSPORE_CORE_OPS_ZEROS_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameZeros = "Zeros"; | |||||
| class Zeros : public PrimitiveC { | |||||
| public: | |||||
| Zeros() : PrimitiveC(kNameZeros) {} | |||||
| ~Zeros() = default; | |||||
| MS_DECLARE_PARENT(Zeros, PrimitiveC); | |||||
| void Init() {} | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_ZEROS_H_ | |||||
| @@ -442,8 +442,8 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty | |||||
| auto type = types.begin()->second; | auto type = types.begin()->second; | ||||
| MS_EXCEPTION_IF_NULL(type); | MS_EXCEPTION_IF_NULL(type); | ||||
| if (!type->isa<TensorType>()) { | if (!type->isa<TensorType>()) { | ||||
| MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << types.begin()->first << " input must be a tensor but got " | |||||
| << type->ToString(); | |||||
| MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << types.begin()->first | |||||
| << " input must be a tensor but got " << type->ToString(); | |||||
| } | } | ||||
| TypePtr check_type = _CheckTypeSame(types, prim_name, false); | TypePtr check_type = _CheckTypeSame(types, prim_name, false); | ||||
| return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name); | return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name); | ||||
| @@ -599,4 +599,27 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) { | |||||
| MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode."; | MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode."; | ||||
| } | } | ||||
| } | } | ||||
| std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr, | |||||
| const std::string &prim_name) { | |||||
| std::vector<int64_t> result; | |||||
| MS_EXCEPTION_IF_NULL(attr); | |||||
| if (attr->isa<ValueTuple>()) { | |||||
| std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value(); | |||||
| (void)std::transform( | |||||
| attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t { | |||||
| if (!e->isa<Int64Imm>()) { | |||||
| MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Int64"; | |||||
| } | |||||
| return GetValue<int64_t>(e); | |||||
| }); | |||||
| } else { | |||||
| if (!attr->isa<Int64Imm>()) { | |||||
| MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Int64"; | |||||
| } | |||||
| int64_t attr_val = attr->cast<Int64ImmPtr>()->value(); | |||||
| result.push_back(attr_val); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -321,6 +321,8 @@ class CheckAndConvertUtils { | |||||
| static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value, | static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value, | ||||
| const std::string &class_name); | const std::string &class_name); | ||||
| static void CheckMode(const std::string &class_name); | static void CheckMode(const std::string &class_name); | ||||
| static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr, | |||||
| const std::string &arg_name); | |||||
| private: | private: | ||||
| static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2); | static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2); | ||||
| @@ -17,8 +17,10 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | |||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| auto type_id = ExtractTypeId(type_ptr); | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | |||||
| size_t mem_size = IntToSize(tensor->ElementsNum()); | size_t mem_size = IntToSize(tensor->ElementsNum()); | ||||
| auto tensor_data = tensor->data_c(); | auto tensor_data = tensor->data_c(); | ||||
| char *data = reinterpret_cast<char *>(tensor_data); | char *data = reinterpret_cast<char *>(tensor_data); | ||||
| @@ -28,8 +30,10 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | |||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| auto type_id = ExtractTypeId(type_ptr); | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | |||||
| size_t mem_size = IntToSize(tensor->ElementsNum()); | size_t mem_size = IntToSize(tensor->ElementsNum()); | ||||
| if (tensor->data_type() == kNumberTypeFloat32) { | if (tensor->data_type() == kNumberTypeFloat32) { | ||||
| SetTensorData<float>(tensor->data_c(), 1.0, mem_size); | SetTensorData<float>(tensor->data_c(), 1.0, mem_size); | ||||
| @@ -39,8 +43,18 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std: | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector<int64_t> &shape, void *data) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape, data, type); | |||||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape, | |||||
| void *data) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| auto type_id = ExtractTypeId(type_ptr); | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape, data, type_id); | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| auto tensor_type = type_ptr->cast<TensorTypePtr>(); | |||||
| auto type_id = tensor_type->element()->type_id(); | |||||
| return type_id; | |||||
| } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,9 +30,10 @@ void SetTensorData(void *data, T num, size_t data_length) { | |||||
| } | } | ||||
| class TensorConstructUtils { | class TensorConstructUtils { | ||||
| public: | public: | ||||
| static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateTensor(TypeId type, const std::vector<int64_t> &shape, void *data); | |||||
| static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data); | |||||
| static TypeId ExtractTypeId(const TypePtr type); | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | ||||
| @@ -1342,27 +1342,6 @@ class Zeros(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize Zeros""" | """Initialize Zeros""" | ||||
| def __infer__(self, dims, dtype): | |||||
| if isinstance(dims['value'], int): | |||||
| shape = (dims['value'],) | |||||
| else: | |||||
| shape = dims['value'] | |||||
| validator.check_value_type("shape", shape, [tuple], self.name) | |||||
| for i, item in enumerate(shape): | |||||
| validator.check_non_negative_int(item, shape[i], self.name) | |||||
| valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||||
| mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, | |||||
| mstype.float16, mstype.float32, mstype.float64] | |||||
| validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) | |||||
| x_nptype = mstype.dtype_to_nptype(dtype['value']) | |||||
| ret = np.zeros(shape, x_nptype) | |||||
| out = { | |||||
| 'value': Tensor(ret), | |||||
| 'shape': shape, | |||||
| 'dtype': x_nptype, | |||||
| } | |||||
| return out | |||||
| class OnesLike(PrimitiveWithInfer): | class OnesLike(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -5193,30 +5172,6 @@ class GatherD(PrimitiveWithInfer): | |||||
| """Initialize GatherD""" | """Initialize GatherD""" | ||||
| self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output']) | ||||
| def __infer__(self, x, dim, index): | |||||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_subclass("dim", dim['dtype'], [mstype.int32, mstype.int64], self.name) | |||||
| x_shp = x['shape'] | |||||
| idx_shp = index['shape'] | |||||
| x_rank = len(x_shp) | |||||
| idx_rank = len(idx_shp) | |||||
| validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name) | |||||
| dim_v = dim['value'] | |||||
| validator.check("dim value", dim_v, "expected", -x_rank, Rel.GE, self.name) | |||||
| validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name) | |||||
| if dim_v < 0: | |||||
| dim['value'] = dim_v + x_rank | |||||
| for i in range(x_rank): | |||||
| if i == dim['value']: | |||||
| continue | |||||
| validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name) | |||||
| out = {'shape': index['shape'], | |||||
| 'dtype': x['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||
| class Identity(PrimitiveWithInfer): | class Identity(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -89,17 +89,6 @@ class ScalarSummary(PrimitiveWithInfer): | |||||
| """init""" | """init""" | ||||
| self.add_prim_attr("side_effect_io", True) | self.add_prim_attr("side_effect_io", True) | ||||
| def __infer__(self, name, value): | |||||
| _check_summary_param(name, value, self.__class__.__name__) | |||||
| v_shape = value['shape'] | |||||
| # In the summary, the value whose shape is [1] is also considered as a scalar. | |||||
| if v_shape and v_shape != [1]: | |||||
| raise ValueError(f"For 'value' the type should be scalar, " | |||||
| f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.") | |||||
| return SUMMARY_RETURN_VALUE | |||||
| class ImageSummary(PrimitiveWithInfer): | class ImageSummary(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -191,17 +180,6 @@ class TensorSummary(PrimitiveWithInfer): | |||||
| """init""" | """init""" | ||||
| self.add_prim_attr("side_effect_io", True) | self.add_prim_attr("side_effect_io", True) | ||||
| def __infer__(self, name, value): | |||||
| _check_summary_param(name, value, self.__class__.__name__) | |||||
| v_shape = value['shape'] | |||||
| # In the summary, the value whose shape is [] is not considered as a tensor. | |||||
| if not v_shape: | |||||
| raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " | |||||
| f"shape should not be [].") | |||||
| return SUMMARY_RETURN_VALUE | |||||
| class HistogramSummary(PrimitiveWithInfer): | class HistogramSummary(PrimitiveWithInfer): | ||||
| """ | """ | ||||