| @@ -77,7 +77,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}}, | {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}}, | ||||
| {prim::kPrimUnique, {InferImplUnique, nullptr, true}}, | {prim::kPrimUnique, {InferImplUnique, nullptr, true}}, | ||||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}}, | {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}}, | ||||
| {prim::kPrimGather, {InferImplGatherV2, nullptr, true}}, | |||||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}}, | {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}}, | ||||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}}, | {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}}, | ||||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}}, | {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}}, | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameBroadcastTo = "BroadcastTo"; | |||||
| class BroadcastTo : public PrimitiveC { | class BroadcastTo : public PrimitiveC { | ||||
| public: | public: | ||||
| BroadcastTo() : PrimitiveC(kNameBroadcastTo) {} | |||||
| BroadcastTo() : PrimitiveC(prim::kPrimBroadcastTo->name()) {} | |||||
| ~BroadcastTo() = default; | ~BroadcastTo() = default; | ||||
| MS_DECLARE_PARENT(BroadcastTo, PrimitiveC); | MS_DECLARE_PARENT(BroadcastTo, PrimitiveC); | ||||
| void Init(const std::vector<int64_t> &shape); | void Init(const std::vector<int64_t> &shape); | ||||
| @@ -16,10 +16,79 @@ | |||||
| #include <set> | #include <set> | ||||
| #include <memory> | #include <memory> | ||||
| #include <algorithm> | |||||
| #include "ops/gather.h" | #include "ops/gather.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| REGISTER_PRIMITIVE_C(kNameGather, Gather); | |||||
| // gather | |||||
| AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| const std::string &op_name = primitive->name(); | |||||
| abstract::CheckArgsSize(op_name, input_args, 3); | |||||
| abstract::AbstractTensorPtr params = | |||||
| CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0); | |||||
| abstract::AbstractTensorPtr indices = | |||||
| CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1); | |||||
| // check | |||||
| std::set<TypePtr> valid_params_types = {kTensorType}; | |||||
| CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name); | |||||
| std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name); | |||||
| CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name); | |||||
| bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()); | |||||
| bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty()); | |||||
| int64_t axis_val = 0; | |||||
| // 3rd input is a Tensor when Gather is a dynamic shape operator | |||||
| if (input_args[2]->isa<abstract::AbstractTensor>()) { | |||||
| auto axis = input_args[2]->cast<abstract::AbstractTensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(axis); | |||||
| auto axis_value_ptr = axis->BuildValue(); | |||||
| MS_EXCEPTION_IF_NULL(axis_value_ptr); | |||||
| auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(axis_tensor); | |||||
| axis_val = *static_cast<int64_t *>(axis_tensor->data_c()); | |||||
| } else if (input_args[2]->isa<abstract::AbstractScalar>()) { | |||||
| auto axis = input_args[2]->cast<abstract::AbstractScalarPtr>(); | |||||
| axis_val = GetValue<int64_t>(axis->BuildValue()); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name(); | |||||
| } | |||||
| auto params_shp = params->shape()->shape(); | |||||
| auto indices_shp = indices->shape()->shape(); | |||||
| auto params_rank = static_cast<int64_t>(params_shp.size()); | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_rank}, op_name); | |||||
| // either inputs or both can be dynamic and computation requires min/max shapes for both | |||||
| ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape(); | |||||
| ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape(); | |||||
| ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape(); | |||||
| ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape(); | |||||
| // check axis_val within interval: [-params_rank, params_rank) | |||||
| if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) { | |||||
| MS_LOG(EXCEPTION) << "For Gather - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) " | |||||
| << "Got " << axis_val << "."; | |||||
| } | |||||
| if (axis_val < 0) { | |||||
| axis_val += params_rank; | |||||
| } | |||||
| auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector { | |||||
| ShapeVector out_vec; | |||||
| std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec)); | |||||
| copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec)); | |||||
| copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec)); | |||||
| return out_vec; | |||||
| }; | |||||
| ShapeVector out_shape = calc_shape(indices_shp, params_shp); | |||||
| if (ind_dyn || param_dyn) { | |||||
| ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min); | |||||
| ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max); | |||||
| return std::make_shared<abstract::AbstractTensor>( | |||||
| params->element(), std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape)); | |||||
| } | |||||
| return std::make_shared<abstract::AbstractTensor>(params->element(), std::make_shared<abstract::Shape>(out_shape)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,7 +29,7 @@ namespace ops { | |||||
| constexpr auto kNameGather = "Gather"; | constexpr auto kNameGather = "Gather"; | ||||
| class Gather : public PrimitiveC { | class Gather : public PrimitiveC { | ||||
| public: | public: | ||||
| Gather() : PrimitiveC(kNameGather) { InitIOName({"x", "dim", "index"}, {"output"}); } | |||||
| Gather() : PrimitiveC(kNameGather) { InitIOName({"param", "indices", "axis"}, {"output"}); } | |||||
| ~Gather() = default; | ~Gather() = default; | ||||
| MS_DECLARE_PARENT(Gather, PrimitiveC); | MS_DECLARE_PARENT(Gather, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| @@ -33,7 +33,9 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v | |||||
| auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); | auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name); | ||||
| int64_t x_rank = x_shape.size(); | int64_t x_rank = x_shape.size(); | ||||
| CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); | CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); | ||||
| auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue()); | |||||
| auto value_ptr = input_args[1]->BuildValue(); | |||||
| MS_EXCEPTION_IF_NULL(value_ptr); | |||||
| auto dim_v = GetValue<int64_t>(value_ptr); | |||||
| CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name); | CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name); | ||||
| CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name); | CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name); | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameGatherD = "GatherD"; | |||||
| class GatherD : public PrimitiveC { | class GatherD : public PrimitiveC { | ||||
| public: | public: | ||||
| GatherD() : PrimitiveC(kNameGatherD) { InitIOName({"x", "dim", "index"}, {"output"}); } | |||||
| GatherD() : PrimitiveC(prim::kPrimGatherD->name()) { InitIOName({"x", "dim", "index"}, {"output"}); } | |||||
| ~GatherD() = default; | ~GatherD() = default; | ||||
| MS_DECLARE_PARENT(GatherD, PrimitiveC); | MS_DECLARE_PARENT(GatherD, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameLog1p = "Log1p"; | |||||
| class Log1p : public PrimitiveC { | class Log1p : public PrimitiveC { | ||||
| public: | public: | ||||
| Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); } | |||||
| Log1p() : PrimitiveC(prim::kPrimLog1p->name()) { InitIOName({"x"}, {"y"}); } | |||||
| ~Log1p() = default; | ~Log1p() = default; | ||||
| MS_DECLARE_PARENT(Log1p, PrimitiveC); | MS_DECLARE_PARENT(Log1p, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameScalarSummary = "ScalarSummary"; | |||||
| class ScalarSummary : public PrimitiveC { | class ScalarSummary : public PrimitiveC { | ||||
| public: | public: | ||||
| ScalarSummary() : PrimitiveC(kNameScalarSummary) {} | |||||
| ScalarSummary() : PrimitiveC(prim::kPrimScalarSummary->name()) {} | |||||
| ~ScalarSummary() = default; | ~ScalarSummary() = default; | ||||
| MS_DECLARE_PARENT(ScalarSummary, PrimitiveC); | MS_DECLARE_PARENT(ScalarSummary, PrimitiveC); | ||||
| void Init(); | void Init(); | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameSoftplus = "Softplus"; | |||||
| class Softplus : public PrimitiveC { | class Softplus : public PrimitiveC { | ||||
| public: | public: | ||||
| Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); } | |||||
| Softplus() : PrimitiveC(prim::kPrimSoftplus->name()) { InitIOName({"x"}, {"output"}); } | |||||
| ~Softplus() = default; | ~Softplus() = default; | ||||
| MS_DECLARE_PARENT(Softplus, PrimitiveC); | MS_DECLARE_PARENT(Softplus, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameTensorSummary = "TensorSummary"; | |||||
| class TensorSummary : public PrimitiveC { | class TensorSummary : public PrimitiveC { | ||||
| public: | public: | ||||
| TensorSummary() : PrimitiveC(kNameTensorSummary) {} | |||||
| TensorSummary() : PrimitiveC(prim::kPrimTensorSummary->name()) {} | |||||
| ~TensorSummary() = default; | ~TensorSummary() = default; | ||||
| MS_DECLARE_PARENT(TensorSummary, PrimitiveC); | MS_DECLARE_PARENT(TensorSummary, PrimitiveC); | ||||
| void Init(); | void Init(); | ||||
| @@ -27,10 +27,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameZeros = "Zeros"; | |||||
| class Zeros : public PrimitiveC { | class Zeros : public PrimitiveC { | ||||
| public: | public: | ||||
| Zeros() : PrimitiveC(kNameZeros) {} | |||||
| Zeros() : PrimitiveC(prim::kPrimZeros->name()) {} | |||||
| ~Zeros() = default; | ~Zeros() = default; | ||||
| MS_DECLARE_PARENT(Zeros, PrimitiveC); | MS_DECLARE_PARENT(Zeros, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| @@ -807,7 +807,7 @@ class Unique(Primitive): | |||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| class Gather(PrimitiveWithCheck): | |||||
| class Gather(Primitive): | |||||
| r""" | r""" | ||||
| Returns a slice of the input tensor based on the specified indices and axis. | Returns a slice of the input tensor based on the specified indices and axis. | ||||
| @@ -852,15 +852,6 @@ class Gather(PrimitiveWithCheck): | |||||
| """Initialize index_select""" | """Initialize index_select""" | ||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | ||||
| def __check__(self, params, indices, axis): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) | |||||
| axis_v = axis['value'] | |||||
| validator.check_value_type('axis', axis_v, [int], self.name) | |||||
| rank = len(params['shape']) | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||||
| class GatherV2(PrimitiveWithCheck): | class GatherV2(PrimitiveWithCheck): | ||||
| """ | """ | ||||