|
|
|
@@ -36,16 +36,19 @@ struct ComputeParams { |
|
|
|
size_t x_mem_size_{0}; |
|
|
|
}; |
|
|
|
|
|
|
|
class ScatterNdUpdateCPUKernel : public CPUKernel { |
|
|
|
class ScatterUpdateCPUKernel : public CPUKernel { |
|
|
|
public: |
|
|
|
ScatterNdUpdateCPUKernel() = default; |
|
|
|
~ScatterNdUpdateCPUKernel() override = default; |
|
|
|
ScatterUpdateCPUKernel() = default; |
|
|
|
~ScatterUpdateCPUKernel() override = default; |
|
|
|
|
|
|
|
void InitKernel(const CNodePtr &kernel_node) override; |
|
|
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, |
|
|
|
const std::vector<AddressPtr> &outputs) override; |
|
|
|
|
|
|
|
virtual void *ScatterUpdateRealData(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) = 0; |
|
|
|
|
|
|
|
private: |
|
|
|
template <typename T> |
|
|
|
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); |
|
|
|
@@ -57,13 +60,25 @@ class ScatterNdUpdateCPUKernel : public CPUKernel { |
|
|
|
std::vector<int> out_strides_; |
|
|
|
}; |
|
|
|
|
|
|
|
class ScatterNdUpdateCPUKernel : public ScatterUpdateCPUKernel { |
|
|
|
protected: |
|
|
|
void *ScatterUpdateRealData(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) override; |
|
|
|
}; |
|
|
|
|
|
|
|
class TensorScatterUpdateCPUKernel : public ScatterUpdateCPUKernel { |
|
|
|
protected: |
|
|
|
void *ScatterUpdateRealData(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) override; |
|
|
|
}; |
|
|
|
|
|
|
|
MS_REG_CPU_KERNEL(ScatterNdUpdate, |
|
|
|
KernelAttr() |
|
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
|
ScatterNdUpdateCPUKernel); |
|
|
|
ScatterNdUpdateCPUKernel) |
|
|
|
|
|
|
|
MS_REG_CPU_KERNEL(TensorScatterUpdate, |
|
|
|
KernelAttr() |
|
|
|
@@ -71,7 +86,7 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate, |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddInputAttr(kNumberTypeFloat32) |
|
|
|
.AddOutputAttr(kNumberTypeFloat32), |
|
|
|
ScatterNdUpdateCPUKernel); |
|
|
|
TensorScatterUpdateCPUKernel) |
|
|
|
|
|
|
|
MS_REG_CPU_KERNEL(ScatterNdUpdate, |
|
|
|
KernelAttr() |
|
|
|
@@ -79,7 +94,7 @@ MS_REG_CPU_KERNEL(ScatterNdUpdate, |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddInputAttr(kNumberTypeFloat64) |
|
|
|
.AddOutputAttr(kNumberTypeFloat64), |
|
|
|
ScatterNdUpdateCPUKernel); |
|
|
|
ScatterNdUpdateCPUKernel) |
|
|
|
|
|
|
|
MS_REG_CPU_KERNEL(TensorScatterUpdate, |
|
|
|
KernelAttr() |
|
|
|
@@ -87,7 +102,7 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate, |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddInputAttr(kNumberTypeFloat64) |
|
|
|
.AddOutputAttr(kNumberTypeFloat64), |
|
|
|
ScatterNdUpdateCPUKernel); |
|
|
|
TensorScatterUpdateCPUKernel) |
|
|
|
|
|
|
|
MS_REG_CPU_KERNEL(ScatterNdUpdate, |
|
|
|
KernelAttr() |
|
|
|
@@ -103,7 +118,7 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate, |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddOutputAttr(kNumberTypeInt32), |
|
|
|
ScatterNdUpdateCPUKernel); |
|
|
|
TensorScatterUpdateCPUKernel) |
|
|
|
|
|
|
|
MS_REG_CPU_KERNEL(ScatterNdUpdate, |
|
|
|
KernelAttr() |
|
|
|
@@ -119,7 +134,7 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate, |
|
|
|
.AddInputAttr(kNumberTypeInt32) |
|
|
|
.AddInputAttr(kNumberTypeInt64) |
|
|
|
.AddOutputAttr(kNumberTypeInt64), |
|
|
|
ScatterNdUpdateCPUKernel); |
|
|
|
TensorScatterUpdateCPUKernel) |
|
|
|
} // namespace kernel |
|
|
|
} // namespace mindspore |
|
|
|
|
|
|
|
|