|
|
|
@@ -226,42 +226,18 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri |
|
|
|
auto segment_ids_shape = segment_ids->shape()->shape(); |
|
|
|
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentSum should be %s"); |
|
|
|
(void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentSum should be %s"); |
|
|
|
// check if dynamic shape |
|
|
|
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); |
|
|
|
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic shape |
|
|
|
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty()); |
|
|
|
bool op_is_dynamic = x_is_dyn || ids_is_dyn; |
|
|
|
auto x_shape = x->shape()->shape(); |
|
|
|
ShapeVector shape; |
|
|
|
int64_t num_segments_value = 0; |
|
|
|
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor |
|
|
|
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments); |
|
|
|
auto num_segments_value_ptr = num_segments->BuildValue(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments_value_ptr); |
|
|
|
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments_tensor); |
|
|
|
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); |
|
|
|
} else { |
|
|
|
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c()); |
|
|
|
} |
|
|
|
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar |
|
|
|
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); |
|
|
|
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); |
|
|
|
} else { |
|
|
|
num_segments_value = GetValue<int32_t>(num_segments->BuildValue()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentSum"; |
|
|
|
} |
|
|
|
int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name); |
|
|
|
if (num_segments_value <= 0) { |
|
|
|
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentSum"; |
|
|
|
} |
|
|
|
shape.emplace_back(num_segments_value); |
|
|
|
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end()); |
|
|
|
// dims check |
|
|
|
if (!op_is_dynamic) { |
|
|
|
if (!op_is_dynamic) { // not dynamic |
|
|
|
for (size_t i = 0; i < segment_ids_shape.size(); i++) { |
|
|
|
if (x_shape[i] != segment_ids_shape[i]) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values"; |
|
|
|
@@ -269,16 +245,14 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri |
|
|
|
} |
|
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape)); |
|
|
|
} |
|
|
|
// is dynamic |
|
|
|
ShapeVector min_shape; |
|
|
|
ShapeVector max_shape; |
|
|
|
min_shape.emplace_back(num_segments_value); |
|
|
|
max_shape.emplace_back(num_segments_value); |
|
|
|
// only run validation if shape values are known |
|
|
|
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); |
|
|
|
bool ids_any_shape = |
|
|
|
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); |
|
|
|
if (!x_any_shape && !ids_any_shape) { |
|
|
|
if (!x_any_shape && !ids_any_shape) { // only validate when shapes fully known |
|
|
|
for (size_t i = 0; i < segment_ids_shape.size(); i++) { |
|
|
|
if (x_shape[i] != segment_ids_shape[i]) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape values of segments_ids must match with corresponding x shape values"; |
|
|
|
@@ -307,52 +281,27 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri |
|
|
|
auto segment_ids_shape = segment_ids->shape()->shape(); |
|
|
|
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s"); |
|
|
|
(void)CheckTensorDType(segment_ids, {kInt32, kInt64}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s"); |
|
|
|
// check if dynamic shape |
|
|
|
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); |
|
|
|
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic |
|
|
|
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty()); |
|
|
|
bool op_is_dynamic = x_is_dyn || ids_is_dyn; |
|
|
|
auto x_shape = x->shape()->shape(); |
|
|
|
ShapeVector shape; |
|
|
|
int64_t num_segments_value = 0; |
|
|
|
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor |
|
|
|
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments); |
|
|
|
auto num_segments_value_ptr = num_segments->BuildValue(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments_value_ptr); |
|
|
|
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments_tensor); |
|
|
|
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); |
|
|
|
} else { |
|
|
|
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c()); |
|
|
|
} |
|
|
|
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar |
|
|
|
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); |
|
|
|
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); |
|
|
|
} else { |
|
|
|
num_segments_value = GetValue<int32_t>(num_segments->BuildValue()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax"; |
|
|
|
} |
|
|
|
int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name); |
|
|
|
if (num_segments_value <= 0) { |
|
|
|
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax"; |
|
|
|
} |
|
|
|
shape.emplace_back(num_segments_value); |
|
|
|
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end()); |
|
|
|
if (!op_is_dynamic) { |
|
|
|
if (!op_is_dynamic) { // not dynamic |
|
|
|
if (x_shape[0] != segment_ids_shape[0]) { |
|
|
|
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax"; |
|
|
|
} |
|
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape)); |
|
|
|
} |
|
|
|
// is dynamic |
|
|
|
ShapeVector min_shape; |
|
|
|
ShapeVector max_shape; |
|
|
|
min_shape.emplace_back(num_segments_value); |
|
|
|
max_shape.emplace_back(num_segments_value); |
|
|
|
// only run validation if shape values are known |
|
|
|
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); |
|
|
|
bool ids_any_shape = |
|
|
|
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); |
|
|
|
@@ -383,56 +332,31 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri |
|
|
|
auto segment_ids_shape = segment_ids->shape()->shape(); |
|
|
|
(void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMin should be %s"); |
|
|
|
(void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMin should be %s"); |
|
|
|
// check if dynamic shape |
|
|
|
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); |
|
|
|
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty()); // check if dynamic shape |
|
|
|
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty()); |
|
|
|
bool op_is_dynamic = x_is_dyn || ids_is_dyn; |
|
|
|
auto x_shape = x->shape()->shape(); |
|
|
|
ShapeVector shape; |
|
|
|
int64_t num_segments_value = 0; |
|
|
|
if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor |
|
|
|
auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments); |
|
|
|
auto num_segments_value_ptr = num_segments->BuildValue(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments_value_ptr); |
|
|
|
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(num_segments_tensor); |
|
|
|
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c()); |
|
|
|
} else { |
|
|
|
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c()); |
|
|
|
} |
|
|
|
} else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar |
|
|
|
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2); |
|
|
|
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { |
|
|
|
num_segments_value = GetValue<int64_t>(num_segments->BuildValue()); |
|
|
|
} else { |
|
|
|
num_segments_value = GetValue<int32_t>(num_segments->BuildValue()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin"; |
|
|
|
} |
|
|
|
int64_t num_segments_value = GetUnsortedSegmentOpScalarArg(args_spec_list, op_name); |
|
|
|
if (num_segments_value <= 0) { |
|
|
|
MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMin"; |
|
|
|
} |
|
|
|
shape.emplace_back(num_segments_value); |
|
|
|
shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end()); |
|
|
|
if (!op_is_dynamic) { |
|
|
|
if (!op_is_dynamic) { // not dynamic |
|
|
|
if (x_shape[0] != segment_ids_shape[0]) { |
|
|
|
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin"; |
|
|
|
} |
|
|
|
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape)); |
|
|
|
} |
|
|
|
// is dynamic |
|
|
|
ShapeVector min_shape; |
|
|
|
ShapeVector max_shape; |
|
|
|
min_shape.emplace_back(num_segments_value); |
|
|
|
max_shape.emplace_back(num_segments_value); |
|
|
|
// only run validation if shape values are known |
|
|
|
bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); |
|
|
|
bool ids_any_shape = |
|
|
|
std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); |
|
|
|
if (!x_any_shape && !ids_any_shape) { |
|
|
|
if (!x_any_shape && !ids_any_shape) { // only validate when shapes fully known |
|
|
|
if (x_shape[0] != segment_ids_shape[0]) { |
|
|
|
MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMin"; |
|
|
|
} |
|
|
|
|