Browse Source

gather op infer

pull/15211/head
simson 4 years ago
parent
commit
b43bf90666
12 changed files with 82 additions and 28 deletions
  1. +0
    -1
      mindspore/core/abstract/primitive_infer_map.cc
  2. +1
    -2
      mindspore/core/ops/broadcast_to.h
  3. +70
    -1
      mindspore/core/ops/gather.cc
  4. +1
    -1
      mindspore/core/ops/gather.h
  5. +3
    -1
      mindspore/core/ops/gather_d.cc
  6. +1
    -2
      mindspore/core/ops/gather_d.h
  7. +1
    -2
      mindspore/core/ops/log1p.h
  8. +1
    -2
      mindspore/core/ops/scalar_summary.h
  9. +1
    -2
      mindspore/core/ops/softplus.h
  10. +1
    -2
      mindspore/core/ops/tensor_summary.h
  11. +1
    -2
      mindspore/core/ops/zeros.h
  12. +1
    -10
      mindspore/ops/operations/array_ops.py

+ 0
- 1
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -77,7 +77,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}},
{prim::kPrimUnique, {InferImplUnique, nullptr, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}},
{prim::kPrimGather, {InferImplGatherV2, nullptr, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}},


+ 1
- 2
mindspore/core/ops/broadcast_to.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameBroadcastTo = "BroadcastTo";
class BroadcastTo : public PrimitiveC {
public:
BroadcastTo() : PrimitiveC(kNameBroadcastTo) {}
BroadcastTo() : PrimitiveC(prim::kPrimBroadcastTo->name()) {}
~BroadcastTo() = default;
MS_DECLARE_PARENT(BroadcastTo, PrimitiveC);
void Init(const std::vector<int64_t> &shape);


+ 70
- 1
mindspore/core/ops/gather.cc View File

@@ -16,10 +16,79 @@

#include <set>
#include <memory>
#include <algorithm>
#include "ops/gather.h"

namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameGather, Gather);
// Shape/type inference for the Gather primitive.
// Inputs (3): params (tensor), indices (integer tensor), axis (scalar or 0-d tensor).
// Output: an AbstractTensor whose shape is params' shape with the `axis` dimension
// replaced by the shape of `indices`; propagates min/max shapes when either input
// is dynamic.
AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
abstract::CheckArgsSize(op_name, input_args, 3);
abstract::AbstractTensorPtr params =
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
abstract::AbstractTensorPtr indices =
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
// Type validation: params must be a tensor; indices must be an integer tensor;
// axis must have an integer type (scalar or tensor — see the branch below).
std::set<TypePtr> valid_params_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name);
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name);
CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name);

// An input is treated as dynamic when it carries both non-empty min and max shapes.
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
int64_t axis_val = 0;
// 3rd input is a Tensor when Gather is a dynamic shape operator
if (input_args[2]->isa<abstract::AbstractTensor>()) {
auto axis = input_args[2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(axis);
auto axis_value_ptr = axis->BuildValue();
MS_EXCEPTION_IF_NULL(axis_value_ptr);
auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(axis_tensor);
// NOTE(review): reinterprets the raw buffer as int64_t — assumes the axis
// tensor's dtype is int64; an int32 axis tensor would be misread here.
// TODO confirm callers always pass int64 for the tensor form.
axis_val = *static_cast<int64_t *>(axis_tensor->data_c());
} else if (input_args[2]->isa<abstract::AbstractScalar>()) {
// Static-shape case: axis arrives as a scalar value.
auto axis = input_args[2]->cast<abstract::AbstractScalarPtr>();
axis_val = GetValue<int64_t>(axis->BuildValue());
} else {
MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name();
}
auto params_shp = params->shape()->shape();
auto indices_shp = indices->shape()->shape();
auto params_rank = static_cast<int64_t>(params_shp.size());
// NOTE(review): this range check duplicates the explicit [-rank, rank) check a
// few lines below — one of the two could be dropped in a follow-up.
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis_val, kIncludeLeft, {-params_rank, params_rank}, op_name);
// either inputs or both can be dynamic and computation requires min/max shapes for both
ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
// check axis_val within interval: [-params_rank, params_rank)
if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) {
MS_LOG(EXCEPTION) << "For Gather - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) "
<< "Got " << axis_val << ".";
}
// Normalize a negative axis to its positive equivalent before slicing shapes.
if (axis_val < 0) {
axis_val += params_rank;
}
// Output shape = params[0:axis] ++ indices_shape ++ params[axis+1:]
// (the gathered dimension is replaced by the full indices shape).
auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector &params_vec) -> ShapeVector {
ShapeVector out_vec;
std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec));
copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec));
return out_vec;
};
ShapeVector out_shape = calc_shape(indices_shp, params_shp);
if (ind_dyn || param_dyn) {
// Dynamic case: also derive min/max output shapes so downstream passes can
// reason about the dynamic range.
ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
return std::make_shared<abstract::AbstractTensor>(
params->element(), std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
}
// Fully static case: a plain shape with no min/max bounds.
return std::make_shared<abstract::AbstractTensor>(params->element(), std::make_shared<abstract::Shape>(out_shape));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 1
- 1
mindspore/core/ops/gather.h View File

@@ -29,7 +29,7 @@ namespace ops {
constexpr auto kNameGather = "Gather";
class Gather : public PrimitiveC {
public:
Gather() : PrimitiveC(kNameGather) { InitIOName({"x", "dim", "index"}, {"output"}); }
Gather() : PrimitiveC(kNameGather) { InitIOName({"param", "indices", "axis"}, {"output"}); }
~Gather() = default;
MS_DECLARE_PARENT(Gather, PrimitiveC);
void Init() {}


+ 3
- 1
mindspore/core/ops/gather_d.cc View File

@@ -33,7 +33,9 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name);
int64_t x_rank = x_shape.size();
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue());
auto value_ptr = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(value_ptr);
auto dim_v = GetValue<int64_t>(value_ptr);
CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name);
CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name);



+ 1
- 2
mindspore/core/ops/gather_d.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameGatherD = "GatherD";
class GatherD : public PrimitiveC {
public:
GatherD() : PrimitiveC(kNameGatherD) { InitIOName({"x", "dim", "index"}, {"output"}); }
GatherD() : PrimitiveC(prim::kPrimGatherD->name()) { InitIOName({"x", "dim", "index"}, {"output"}); }
~GatherD() = default;
MS_DECLARE_PARENT(GatherD, PrimitiveC);
void Init() {}


+ 1
- 2
mindspore/core/ops/log1p.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameLog1p = "Log1p";
class Log1p : public PrimitiveC {
public:
Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); }
Log1p() : PrimitiveC(prim::kPrimLog1p->name()) { InitIOName({"x"}, {"y"}); }
~Log1p() = default;
MS_DECLARE_PARENT(Log1p, PrimitiveC);
void Init() {}


+ 1
- 2
mindspore/core/ops/scalar_summary.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameScalarSummary = "ScalarSummary";
class ScalarSummary : public PrimitiveC {
public:
ScalarSummary() : PrimitiveC(kNameScalarSummary) {}
ScalarSummary() : PrimitiveC(prim::kPrimScalarSummary->name()) {}
~ScalarSummary() = default;
MS_DECLARE_PARENT(ScalarSummary, PrimitiveC);
void Init();


+ 1
- 2
mindspore/core/ops/softplus.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameSoftplus = "Softplus";
class Softplus : public PrimitiveC {
public:
Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); }
Softplus() : PrimitiveC(prim::kPrimSoftplus->name()) { InitIOName({"x"}, {"output"}); }
~Softplus() = default;
MS_DECLARE_PARENT(Softplus, PrimitiveC);
void Init() {}


+ 1
- 2
mindspore/core/ops/tensor_summary.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameTensorSummary = "TensorSummary";
class TensorSummary : public PrimitiveC {
public:
TensorSummary() : PrimitiveC(kNameTensorSummary) {}
TensorSummary() : PrimitiveC(prim::kPrimTensorSummary->name()) {}
~TensorSummary() = default;
MS_DECLARE_PARENT(TensorSummary, PrimitiveC);
void Init();


+ 1
- 2
mindspore/core/ops/zeros.h View File

@@ -27,10 +27,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameZeros = "Zeros";
class Zeros : public PrimitiveC {
public:
Zeros() : PrimitiveC(kNameZeros) {}
Zeros() : PrimitiveC(prim::kPrimZeros->name()) {}
~Zeros() = default;
MS_DECLARE_PARENT(Zeros, PrimitiveC);
void Init() {}


+ 1
- 10
mindspore/ops/operations/array_ops.py View File

@@ -807,7 +807,7 @@ class Unique(Primitive):
self.init_prim_io_names(inputs=['x'], outputs=['output'])


class Gather(PrimitiveWithCheck):
class Gather(Primitive):
r"""
Returns a slice of the input tensor based on the specified indices and axis.

@@ -852,15 +852,6 @@ class Gather(PrimitiveWithCheck):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


class GatherV2(PrimitiveWithCheck):
"""


Loading…
Cancel
Save