Browse Source

!10089 [MS][DynamicOP][CI Alarm] Updated InferImpl functions to remove duplicate code and fix 1 CI alarm

From: @danishnxt
Reviewed-by: @tom__chen
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7874877568
3 changed files with 41 additions and 87 deletions
  1. +11
    -87
      mindspore/core/abstract/prim_arrays.cc
  2. +27
    -0
      mindspore/core/abstract/utils.cc
  3. +3
    -0
      mindspore/core/abstract/utils.h

+ 11
- 87
mindspore/core/abstract/prim_arrays.cc View File

@@ -226,42 +226,18 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
auto segment_ids_shape = segment_ids->shape()->shape();
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s");
(void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentSum should be %s");
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic shape
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn || ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else {
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
}
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
}
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentSum";
}
int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
if (num_segments_value <= 0) {
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentSum";
}
shape.emplace_back(num_segments_value);
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
// dims check
if (!op_is_dynamic) {
if (!op_is_dynamic) { // not dynamic
for (size_t i = 0; i < segment_ids_shape.size(); i++) {
if (x_shape[i] != segment_ids_shape[i]) {
MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
@@ -269,16 +245,14 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
}
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
}
// is dynamic
ShapeVector min_shape;
ShapeVector max_shape;
min_shape.emplace_back(num_segments_value);
max_shape.emplace_back(num_segments_value);
// only run validation if shape values are known
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
bool ids_any_shape =
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
if (!x_any_shape && !ids_any_shape) {
if (!x_any_shape && !ids_any_shape) { // only validate when shapes fully known
for (size_t i = 0; i < segment_ids_shape.size(); i++) {
if (x_shape[i] != segment_ids_shape[i]) {
MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values";
@@ -307,52 +281,27 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
auto segment_ids_shape = segment_ids->shape()->shape();
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
(void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn || ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else {
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
}
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
}
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax";
}
int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
if (num_segments_value <= 0) {
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax";
}
shape.emplace_back(num_segments_value);
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
if (!op_is_dynamic) {
if (!op_is_dynamic) { // not dynamic
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
}
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
}
// is dynamic
ShapeVector min_shape;
ShapeVector max_shape;
min_shape.emplace_back(num_segments_value);
max_shape.emplace_back(num_segments_value);
// only run validation if shape values are known
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
bool ids_any_shape =
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
@@ -383,56 +332,31 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri
auto segment_ids_shape = segment_ids->shape()->shape();
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMin should be %s");
(void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMin should be %s");
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic shape
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn || ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else {
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
}
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
}
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin";
}
int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name);
if (num_segments_value <= 0) {
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMin";
}
shape.emplace_back(num_segments_value);
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
if (!op_is_dynamic) {
if (!op_is_dynamic) { // not dynamic
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
}
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
}
// is dynamic
ShapeVector min_shape;
ShapeVector max_shape;
min_shape.emplace_back(num_segments_value);
max_shape.emplace_back(num_segments_value);
// only run validation if shape values are known
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
bool ids_any_shape =
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
if (!x_any_shape && !ids_any_shape) {
if (!x_any_shape && !ids_any_shape) { // only validate when shapes fully known
if (x_shape[0] != segment_ids_shape[0]) {
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin";
}


+ 27
- 0
mindspore/core/abstract/utils.cc View File

@@ -315,5 +315,32 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec
*min_shape = (*min_shape).empty() ? shape : *min_shape;
*max_shape = (*max_shape).empty() ? shape : *max_shape;
}

int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
int64_t num_segments_value = 0;
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else {
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
}
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
}
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in " << op_name;
}
return num_segments_value;
}
} // namespace abstract
} // namespace mindspore

+ 3
- 0
mindspore/core/abstract/utils.h View File

@@ -60,6 +60,9 @@ ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tenso
// Check dynamic shape routine
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);

// Returns the scalar value of the 3rd argument (num_segments) for the UnsortedSegment ops' InferImpl functions.
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name);

} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_

Loading…
Cancel
Save