|
|
|
@@ -63,11 +63,11 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
|
|
|
|
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> & /*workspace*/, |
|
|
|
const std::vector<kernel::AddressPtr> & /*outputs*/) { |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
if (dtype_ == kNumberTypeFloat16) { |
|
|
|
LaunchKernel<float16>(inputs); |
|
|
|
LaunchKernel<float16>(inputs, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32) { |
|
|
|
LaunchKernel<float>(inputs); |
|
|
|
LaunchKernel<float>(inputs, outputs); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Only support float16, float32"; |
|
|
|
return false; |
|
|
|
@@ -76,7 +76,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) { |
|
|
|
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
auto x = reinterpret_cast<T *>(inputs[0]->addr); |
|
|
|
auto indices = reinterpret_cast<int *>(inputs[1]->addr); |
|
|
|
auto updates = reinterpret_cast<T *>(inputs[2]->addr); |
|
|
|
@@ -100,6 +101,10 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input |
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; |
|
|
|
} |
|
|
|
} |
|
|
|
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) { |
|
|
|
|