|
|
|
@@ -38,6 +38,8 @@ int UnsortedSegmentSumCPUKernel::Init() { |
|
|
|
auto input_shape = in_tensors_.at(0)->shape(); |
|
|
|
auto segment_ids_shape = in_tensors_.at(1)->shape(); |
|
|
|
auto output_shape = out_tensors_.at(0)->shape(); |
|
|
|
unit_num_ = 1; |
|
|
|
input_dim1_ = 1; |
|
|
|
for (size_t i = 0; i < input_shape.size(); ++i) { |
|
|
|
unit_num_ *= input_shape[i]; |
|
|
|
if (i >= segment_ids_shape.size()) { |
|
|
|
@@ -45,6 +47,7 @@ int UnsortedSegmentSumCPUKernel::Init() { |
|
|
|
} |
|
|
|
} |
|
|
|
output_dim0_ = output_shape[0]; |
|
|
|
output_dim1_ = 1; |
|
|
|
for (size_t j = 1; j < output_shape.size(); j++) { |
|
|
|
output_dim1_ *= output_shape[j]; |
|
|
|
} |
|
|
|
|