|
|
|
@@ -16,10 +16,39 @@ |
|
|
|
|
|
|
|
#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h" |
|
|
|
#include <string> |
|
|
|
#include <thread> |
|
|
|
#include "runtime/device/cpu/cpu_device_address.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
namespace { |
|
|
|
template <typename T> |
|
|
|
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) { |
|
|
|
MS_EXCEPTION_IF_NULL(params); |
|
|
|
T *x = params->x_; |
|
|
|
int *indices = params->indices_; |
|
|
|
T *updates = params->updates_; |
|
|
|
std::vector<int> *out_strides = params->out_strides_; |
|
|
|
MS_EXCEPTION_IF_NULL(out_strides); |
|
|
|
|
|
|
|
for (size_t i = start; i < end; ++i) { |
|
|
|
int offset = 0; |
|
|
|
for (int j = 0; j < params->indices_unit_rank_; ++j) { |
|
|
|
auto index = indices[i * params->indices_unit_rank_ + j]; |
|
|
|
if (index < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index; |
|
|
|
} |
|
|
|
offset += index * out_strides->at(j) * params->unit_size_; |
|
|
|
} |
|
|
|
auto ret = |
|
|
|
memcpy_s(x + offset, params->x_mem_size_, updates + params->unit_size_ * i, params->unit_size_ * sizeof(T)); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
Check(kernel_node); |
|
|
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
@@ -46,9 +75,9 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
unit_size_ *= SizeToInt(updates_shape[i]); |
|
|
|
} |
|
|
|
num_units_ = 1; |
|
|
|
num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]); |
|
|
|
num_units_ *= updates_shape[indices_shape.size() - 2]; |
|
|
|
for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) { |
|
|
|
num_units_ *= SizeToInt(updates_shape[i]); |
|
|
|
num_units_ *= updates_shape[i]; |
|
|
|
} |
|
|
|
int out_stride = 1; |
|
|
|
out_strides_.push_back(out_stride); |
|
|
|
@@ -56,8 +85,6 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
out_stride *= shape[i + 1]; |
|
|
|
out_strides_.push_back(out_stride); |
|
|
|
} |
|
|
|
shape_ = shape; |
|
|
|
output_unit_offsets_.reserve(num_units_); |
|
|
|
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -79,29 +106,29 @@ template <typename T> |
|
|
|
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
auto x = reinterpret_cast<T *>(inputs[0]->addr); |
|
|
|
auto indices = reinterpret_cast<int *>(inputs[1]->addr); |
|
|
|
auto updates = reinterpret_cast<T *>(inputs[2]->addr); |
|
|
|
ComputeParams<T> params; |
|
|
|
params.x_ = x; |
|
|
|
params.indices_ = reinterpret_cast<int *>(inputs[1]->addr); |
|
|
|
params.updates_ = reinterpret_cast<T *>(inputs[2]->addr); |
|
|
|
params.x_mem_size_ = inputs[0]->size; |
|
|
|
params.unit_size_ = unit_size_; |
|
|
|
params.indices_unit_rank_ = indices_unit_rank_; |
|
|
|
params.out_strides_ = &out_strides_; |
|
|
|
|
|
|
|
for (int i = 0; i < num_units_; ++i) { |
|
|
|
int offset = 0; |
|
|
|
for (int j = 0; j < indices_unit_rank_; ++j) { |
|
|
|
auto index = indices[i * indices_unit_rank_ + j]; |
|
|
|
if (index < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index; |
|
|
|
} |
|
|
|
offset += index * out_strides_[j] * unit_size_; |
|
|
|
} |
|
|
|
output_unit_offsets_[i] = offset; |
|
|
|
const size_t thread_num = 24; |
|
|
|
std::vector<std::thread> threads; |
|
|
|
threads.reserve(thread_num); |
|
|
|
size_t start = 0; |
|
|
|
size_t once_compute_size = (num_units_ + thread_num - 1) / thread_num; |
|
|
|
while (start < num_units_) { |
|
|
|
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size); |
|
|
|
threads.emplace_back(std::thread(Compute<T>, ¶ms, start, end)); |
|
|
|
start += once_compute_size; |
|
|
|
} |
|
|
|
|
|
|
|
auto mem_size = inputs[0]->size; |
|
|
|
for (int i = 0; i < num_units_; i++) { |
|
|
|
auto ret = memcpy_s(x + output_unit_offsets_[i], mem_size, updates + unit_size_ * i, unit_size_ * sizeof(T)); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < threads.size(); ++i) { |
|
|
|
threads[i].join(); |
|
|
|
} |
|
|
|
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size); |
|
|
|
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; |
|
|
|
} |
|
|
|
|