|
|
|
@@ -16,7 +16,6 @@ |
|
|
|
|
|
|
|
#include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h" |
|
|
|
#include <string> |
|
|
|
#include <cmath> |
|
|
|
#include "runtime/device/cpu/cpu_device_address.h" |
|
|
|
#include "common/thread_pool.h" |
|
|
|
|
|
|
|
@@ -79,37 +78,17 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in |
|
|
|
MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); |
|
|
|
if (unit_num_ < thread_num) { |
|
|
|
thread_num = unit_num_; |
|
|
|
} |
|
|
|
std::vector<common::Task> tasks; |
|
|
|
tasks.reserve(thread_num); |
|
|
|
auto task = [&](size_t start, size_t end) { |
|
|
|
for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) { |
|
|
|
if (c * thread_num + start >= unit_num_) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t i = c * thread_num + start; |
|
|
|
size_t j = i / input_dim1_; |
|
|
|
size_t k = i % input_dim1_; |
|
|
|
for (size_t i = 0; i < unit_num_; ++i) { |
|
|
|
size_t j = i / input_dim1_; |
|
|
|
size_t k = i % input_dim1_; |
|
|
|
|
|
|
|
T index = indices_addr[j]; |
|
|
|
if (index < 0 || index >= SizeToInt(output_dim0_)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t output_index = index * output_dim1_ + k; |
|
|
|
output_addr[output_index] += input_addr[i]; |
|
|
|
T index = indices_addr[j]; |
|
|
|
if (index < 0 || index >= SizeToInt(output_dim0_)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
}; |
|
|
|
for (size_t t = 0; t < thread_num; ++t) { |
|
|
|
auto block = [&, t]() { |
|
|
|
task(t, t + 1); |
|
|
|
return common::SUCCESS; |
|
|
|
}; |
|
|
|
tasks.emplace_back(block); |
|
|
|
size_t output_index = index * output_dim1_ + k; |
|
|
|
output_addr[output_index] += input_addr[i]; |
|
|
|
} |
|
|
|
common::ThreadPool::GetInstance().SyncRun(tasks); |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
|