From b43bf90666e850c39331d2778ad8a57d2b4bbc90 Mon Sep 17 00:00:00 2001 From: simson Date: Thu, 15 Apr 2021 09:09:22 +0800 Subject: [PATCH] gather op infer --- .../core/abstract/primitive_infer_map.cc | 1 - mindspore/core/ops/broadcast_to.h | 3 +- mindspore/core/ops/gather.cc | 71 ++++++++++++++++++- mindspore/core/ops/gather.h | 2 +- mindspore/core/ops/gather_d.cc | 4 +- mindspore/core/ops/gather_d.h | 3 +- mindspore/core/ops/log1p.h | 3 +- mindspore/core/ops/scalar_summary.h | 3 +- mindspore/core/ops/softplus.h | 3 +- mindspore/core/ops/tensor_summary.h | 3 +- mindspore/core/ops/zeros.h | 3 +- mindspore/ops/operations/array_ops.py | 11 +-- 12 files changed, 82 insertions(+), 28 deletions(-) diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index cb9fd5fa5e..2728f855e9 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -77,7 +77,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}}, {prim::kPrimUnique, {InferImplUnique, nullptr, true}}, {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}}, - {prim::kPrimGather, {InferImplGatherV2, nullptr, true}}, {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}}, {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}}, {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}}, diff --git a/mindspore/core/ops/broadcast_to.h b/mindspore/core/ops/broadcast_to.h index bb678a1f39..eff8abdcd0 100644 --- a/mindspore/core/ops/broadcast_to.h +++ b/mindspore/core/ops/broadcast_to.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameBroadcastTo = "BroadcastTo"; class BroadcastTo : public PrimitiveC { public: - BroadcastTo() : PrimitiveC(kNameBroadcastTo) {} + BroadcastTo() : PrimitiveC(prim::kPrimBroadcastTo->name()) {} ~BroadcastTo() = default; MS_DECLARE_PARENT(BroadcastTo, PrimitiveC); void Init(const std::vector &shape); diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc index 47a274ebc6..c6d9e5993d 100644 --- a/mindspore/core/ops/gather.cc +++ b/mindspore/core/ops/gather.cc @@ -16,10 +16,79 @@ #include #include +#include #include "ops/gather.h" namespace mindspore { namespace ops { -REGISTER_PRIMITIVE_C(kNameGather, Gather); +// gather +AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const std::string &op_name = primitive->name(); + abstract::CheckArgsSize(op_name, input_args, 3); + abstract::AbstractTensorPtr params = + CheckAndConvertUtils::CheckArgs(op_name, input_args, 0); + abstract::AbstractTensorPtr indices = + CheckAndConvertUtils::CheckArgs(op_name, input_args, 1); + // check + std::set valid_params_types = {kTensorType}; + CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name); + std::set int_types = {kInt8, kInt16, kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name); + CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name); + + bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()); + bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty()); + int64_t axis_val = 0; + // 3rd input is a Tensor when Gather is a dynamic shape operator + if (input_args[2]->isa()) { + auto axis = input_args[2]->cast(); + MS_EXCEPTION_IF_NULL(axis); + auto axis_value_ptr = axis->BuildValue(); + MS_EXCEPTION_IF_NULL(axis_value_ptr); + auto axis_tensor = axis_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(axis_tensor); + axis_val = *static_cast(axis_tensor->data_c()); + } else if (input_args[2]->isa()) { + auto axis = input_args[2]->cast(); + axis_val = GetValue(axis->BuildValue()); + } else { + MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name(); + } + auto params_shp = params->shape()->shape(); + auto indices_shp = indices->shape()->shape(); + auto params_rank = static_cast(params_shp.size()); + CheckAndConvertUtils::CheckInRange("axis", axis_val, kIncludeLeft, {-params_rank, params_rank}, op_name); + // either inputs or both can be dynamic and computation requires min/max shapes for both + ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape(); + ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape(); + ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape(); + ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape(); + // check axis_val within interval: [-params_rank, params_rank) + if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) { + MS_LOG(EXCEPTION) << "For Gather - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) " + << "Got " << axis_val << "."; + } + if (axis_val < 0) { + axis_val += params_rank; + } + auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector { + ShapeVector out_vec; + std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); + copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec)); + copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec)); + return out_vec; + }; + ShapeVector out_shape = calc_shape(indices_shp, params_shp); + if (ind_dyn || param_dyn) { + ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min); + ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max); + return std::make_shared( + params->element(), std::make_shared(out_shape, min_shape, max_shape)); + } + return std::make_shared(params->element(), std::make_shared(out_shape)); +} +REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/gather.h b/mindspore/core/ops/gather.h index 55c735f77c..ea46370cf3 100644 --- a/mindspore/core/ops/gather.h +++ b/mindspore/core/ops/gather.h @@ -29,7 +29,7 @@ namespace ops { constexpr auto kNameGather = "Gather"; class Gather : public PrimitiveC { public: - Gather() : PrimitiveC(kNameGather) { InitIOName({"x", "dim", "index"}, {"output"}); } + Gather() : PrimitiveC(kNameGather) { InitIOName({"param", "indices", "axis"}, {"output"}); } ~Gather() = default; MS_DECLARE_PARENT(Gather, PrimitiveC); void Init() {} diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc index 4d14e3d66e..e4a4aad758 100644 --- a/mindspore/core/ops/gather_d.cc +++ b/mindspore/core/ops/gather_d.cc @@ -33,7 +33,9 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); int64_t x_rank = x_shape.size(); CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); - auto dim_v = GetValue(input_args[1]->BuildValue()); + auto value_ptr = input_args[1]->BuildValue(); + MS_EXCEPTION_IF_NULL(value_ptr); + auto dim_v = GetValue(value_ptr); CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name); CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name); diff --git a/mindspore/core/ops/gather_d.h b/mindspore/core/ops/gather_d.h index d39958b118..76021af100 100644 --- a/mindspore/core/ops/gather_d.h +++ b/mindspore/core/ops/gather_d.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameGatherD = "GatherD"; class GatherD : public PrimitiveC { public: - GatherD() : PrimitiveC(kNameGatherD) { InitIOName({"x", "dim", "index"}, {"output"}); } + GatherD() : PrimitiveC(prim::kPrimGatherD->name()) { InitIOName({"x", "dim", "index"}, {"output"}); } ~GatherD() = default; MS_DECLARE_PARENT(GatherD, PrimitiveC); void Init() {} diff --git a/mindspore/core/ops/log1p.h b/mindspore/core/ops/log1p.h index 0483c19e85..58a8a0004e 100644 --- a/mindspore/core/ops/log1p.h +++ b/mindspore/core/ops/log1p.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameLog1p = "Log1p"; class Log1p : public PrimitiveC { public: - Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); } + Log1p() : PrimitiveC(prim::kPrimLog1p->name()) { InitIOName({"x"}, {"y"}); } ~Log1p() = default; MS_DECLARE_PARENT(Log1p, PrimitiveC); void Init() {} diff --git a/mindspore/core/ops/scalar_summary.h b/mindspore/core/ops/scalar_summary.h index e89c85aca1..c688f3f7b8 100644 --- a/mindspore/core/ops/scalar_summary.h +++ b/mindspore/core/ops/scalar_summary.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameScalarSummary = "ScalarSummary"; class ScalarSummary : public PrimitiveC { public: - ScalarSummary() : PrimitiveC(kNameScalarSummary) {} + ScalarSummary() : PrimitiveC(prim::kPrimScalarSummary->name()) {} ~ScalarSummary() = default; MS_DECLARE_PARENT(ScalarSummary, PrimitiveC); void Init(); diff --git a/mindspore/core/ops/softplus.h b/mindspore/core/ops/softplus.h index cc578e47e4..42bc40518b 100644 --- a/mindspore/core/ops/softplus.h +++ b/mindspore/core/ops/softplus.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameSoftplus = "Softplus"; class Softplus : public PrimitiveC { public: - Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); } + Softplus() : PrimitiveC(prim::kPrimSoftplus->name()) { InitIOName({"x"}, {"output"}); } ~Softplus() = default; MS_DECLARE_PARENT(Softplus, PrimitiveC); void Init() {} diff --git a/mindspore/core/ops/tensor_summary.h b/mindspore/core/ops/tensor_summary.h index 61badeadb2..317e22f668 100644 --- a/mindspore/core/ops/tensor_summary.h +++ b/mindspore/core/ops/tensor_summary.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameTensorSummary = "TensorSummary"; class TensorSummary : public PrimitiveC { public: - TensorSummary() : PrimitiveC(kNameTensorSummary) {} + TensorSummary() : PrimitiveC(prim::kPrimTensorSummary->name()) {} ~TensorSummary() = default; MS_DECLARE_PARENT(TensorSummary, PrimitiveC); void Init(); diff --git a/mindspore/core/ops/zeros.h b/mindspore/core/ops/zeros.h index b9afc5d8e2..d0c23bfc32 100644 --- a/mindspore/core/ops/zeros.h +++ b/mindspore/core/ops/zeros.h @@ -27,10 +27,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameZeros = "Zeros"; class Zeros : public PrimitiveC { public: - Zeros() : PrimitiveC(kNameZeros) {} + Zeros() : PrimitiveC(prim::kPrimZeros->name()) {} ~Zeros() = default; MS_DECLARE_PARENT(Zeros, PrimitiveC); void Init() {} diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 314af566e5..0734be086c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -807,7 +807,7 @@ class Unique(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['output']) -class Gather(PrimitiveWithCheck): +class Gather(Primitive): r""" Returns a slice of the input tensor based on the specified indices and axis. @@ -852,15 +852,6 @@ class Gather(PrimitiveWithCheck): """Initialize index_select""" self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) - def __check__(self, params, indices, axis): - validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) - validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) - axis_v = axis['value'] - validator.check_value_type('axis', axis_v, [int], self.name) - rank = len(params['shape']) - validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) - class GatherV2(PrimitiveWithCheck): """