fix ci (tags/v1.2.0-rc1)
@@ -0,0 +1,61 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      TensorScatterUpdateGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      TensorScatterUpdateGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt8)
                        .AddOutputAttr(kNumberTypeInt8),
                      TensorScatterUpdateGpuFwdKernel, char, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      TensorScatterUpdateGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddOutputAttr(kNumberTypeUInt8),
                      TensorScatterUpdateGpuFwdKernel, uchar, int)
}  // namespace kernel
}  // namespace mindspore
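Each MS_REG_GPU_KERNEL_TWO entry above binds one (value dtype, index dtype) pair to the templated TensorScatterUpdateGpuFwdKernel. As a rough reference for the semantics these kernels implement, here is a minimal NumPy sketch; the scatter_update_reference helper is illustrative only and not part of this change:

import numpy as np

def scatter_update_reference(x, indices, updates):
    # Illustrative helper: each row of `indices` addresses a slice of the leading
    # dimensions of `x`, and the matching entry of `updates` replaces that slice.
    out = x.copy()
    index_depth = indices.shape[-1]  # corresponds to indices_dim_1_ in the kernel
    flat_indices = indices.reshape(-1, index_depth)
    flat_updates = updates.reshape((-1,) + x.shape[index_depth:])
    for idx, upd in zip(flat_indices, flat_updates):
        out[tuple(idx)] = upd
    return out

x = np.arange(21, dtype=np.float32).reshape(3, 7)
indices = np.array([[0, 1], [2, 1]], dtype=np.int32)
updates = np.array([3.2, -1.0], dtype=np.float32)
print(scatter_update_reference(x, indices, updates))  # x with x[0, 1] = 3.2 and x[2, 1] = -1.0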
@@ -0,0 +1,212 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_UPDATE_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_UPDATE_GPU_KERNEL_H

#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class TensorScatterUpdateGpuFwdKernel : public GpuKernel {
 public:
  TensorScatterUpdateGpuFwdKernel()
      : input_size_(1),
        update_size_(1),
        indices_size_(1),
        output_size_(1),
        block_size_(1),
        indices_stride_(nullptr),
        work_shape_(nullptr),
        indices_dim_0_(0),
        indices_dim_1_(0),
        memcpy_flag_(false) {}

  ~TensorScatterUpdateGpuFwdKernel() {
    if (indices_stride_ != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
    }
    if (work_shape_ != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
    }
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    T *input = GetDeviceAddress<T>(inputs, 0);
    S *indices = GetDeviceAddress<S>(inputs, 1);
    T *update = GetDeviceAddress<T>(inputs, 2);
    T *output = GetDeviceAddress<T>(outputs, 0);

    // Copy the host-side index strides and work shape to device memory once per kernel instance.
    if (!memcpy_flag_) {
      const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
      const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(indices_stride_, &vec_indices_stride_[0], indices_len,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync failed in TensorScatterUpdateGpuFwdKernel::Launch.");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync failed in TensorScatterUpdateGpuFwdKernel::Launch.");
      memcpy_flag_ = true;
    }

    CHECK_CUDA_RET_WITH_EXCEPT(
      kernel_node_,
      cudaMemsetAsync(output, static_cast<T>(0.0), output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
      "cudaMemsetAsync failed in TensorScatterUpdateGpuFwdKernel::Launch.");

    const size_t update_size = update_size_ / sizeof(T);
    const size_t output_size = output_size_ / sizeof(T);

    // The CUDA kernel applies the updates in place on `input`; the result is then copied to `output`.
    TensorScatterUpdate(input, indices, update, output, block_size_, update_size, output_size, indices_dim_0_,
                        indices_dim_1_, indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(&output[0], &input[0], input_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync output failed in TensorScatterUpdateGpuFwdKernel::Launch.");
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    memcpy_flag_ = false;
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 3) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but TensorScatterUpdate needs 3 inputs.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but TensorScatterUpdate has 1 output.";
      return false;
    }

    update_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

    std::vector<size_t> shape_me = input_shapes_;
    (void)std::transform(shape_me.begin(), shape_me.end(), std::back_inserter(vec_work_shape_),
                         [](const size_t &value) { return static_cast<S>(value); });

    GetSize();

    const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
    void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
    if (indices_stride_work == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
    }
    indices_stride_ = static_cast<S *>(indices_stride_work);

    const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
    void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
    if (work_shape_work == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
    }
    work_shape_ = static_cast<S *>(work_shape_work);

    InitSizeLists();
    return true;
  }
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(indices_size_);
    input_size_list_.push_back(update_size_);
    output_size_list_.push_back(output_size_);
    return;
  }

  void GetSize() {
    input_size_ = sizeof(T);
    for (size_t i = 0; i < input_shapes_.size(); i++) {
      input_size_ *= input_shapes_[i];
    }
    indices_size_ = sizeof(S);
    for (size_t i = 0; i < indices_shapes_.size(); i++) {
      indices_size_ *= indices_shapes_[i];
    }
    update_size_ = sizeof(T);
    for (size_t i = 0; i < update_shapes_.size(); i++) {
      update_size_ *= update_shapes_[i];
    }
    output_size_ = sizeof(T);
    for (size_t i = 0; i < output_shapes_.size(); i++) {
      output_size_ *= output_shapes_[i];
    }

    // calculate indices dim 0/1
    indices_dim_0_ = indices_shapes_[0];
    indices_dim_1_ = indices_shapes_[indices_shapes_.size() - 1];

    // calculate block_size
    for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
      block_size_ *= output_shapes_[i];
    }

    // calculate indices_stride
    vec_indices_stride_.resize(indices_dim_1_, 0);
    vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
    for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
      vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
    }
  }
 private:
  std::vector<size_t> update_shapes_;
  std::vector<size_t> indices_shapes_;
  std::vector<size_t> input_shapes_;
  std::vector<size_t> output_shapes_;
  std::vector<S> vec_indices_stride_;
  std::vector<S> vec_work_shape_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  size_t input_size_;
  size_t update_size_;
  size_t indices_size_;
  size_t output_size_;
  size_t block_size_;

  S *indices_stride_;
  S *work_shape_;
  size_t indices_dim_0_;
  size_t indices_dim_1_;
  bool memcpy_flag_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_UPDATE_GPU_KERNEL_H
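For orientation, a small NumPy sketch of the host-side bookkeeping done in GetSize(): block_size_ is the element count of the trailing (non-indexed) output dimensions, and vec_indices_stride_ converts one index row into a flat offset. The compute_strides helper below is illustrative only:

import numpy as np

def compute_strides(output_shape, index_depth):
    # Illustrative mirror of GetSize(): block_size is the product of the trailing
    # dimensions not addressed by an index row; strides[k] is the flat step for a
    # unit increment of index component k.
    block_size = int(np.prod(output_shape[index_depth:], dtype=np.int64))
    strides = [0] * index_depth
    strides[-1] = block_size
    for k in range(index_depth - 1, 0, -1):
        strides[k - 1] = strides[k] * output_shape[k]
    return block_size, strides

print(compute_strides((3, 7), 2))     # (1, [7, 1])
print(compute_strides((4, 2, 3), 3))  # (1, [6, 3, 1])
print(compute_strides((4, 5, 6), 2))  # (6, [30, 6])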
@@ -0,0 +1,84 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
                                          const size_t input_size, const size_t output_size,
                                          const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
                                          S *work_shape) {
  // Note: input_size here is the number of elements in `update` (one scatter write per element);
  // the updates are written in place into `input`.
  int i, j;
  for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
       read_index += blockDim.x * gridDim.x) {
    size_t write_index = 0;
    bool out_bound = false;

    i = read_index / block_size;
    j = read_index % block_size;

    // Convert the index row into a flat offset; flag any index that falls outside the work shape.
    for (size_t k = 0; k < indices_dim_1; k++) {
      S indices_i = indices[i * indices_dim_1 + k];
      out_bound |= indices_i >= work_shape[k];
      write_index += indices_i * indices_stride[k];
    }

    write_index += j;
    out_bound |= write_index >= output_size;

    if (!out_bound) {
      input[write_index] = update[read_index];
    }
  }
}
template <typename T, typename S>
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size,
                         const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                         const size_t &indices_dim_1, S *indices_stride, S *work_shape, cudaStream_t stream) {
  TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
    input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride,
    work_shape);
  return;
}

template void TensorScatterUpdate<half, int>(half *input, int *indices, half *update, half *output,
                                             const size_t &block_size, const size_t &input_size,
                                             const size_t &output_size, const size_t &indices_dim_0,
                                             const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                             cudaStream_t stream);
template void TensorScatterUpdate<float, int>(float *input, int *indices, float *update, float *output,
                                              const size_t &block_size, const size_t &input_size,
                                              const size_t &output_size, const size_t &indices_dim_0,
                                              const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                              cudaStream_t stream);
template void TensorScatterUpdate<char, int>(char *input, int *indices, char *update, char *output,
                                             const size_t &block_size, const size_t &input_size,
                                             const size_t &output_size, const size_t &indices_dim_0,
                                             const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                             cudaStream_t stream);
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
                                            const size_t &block_size, const size_t &input_size,
                                            const size_t &output_size, const size_t &indices_dim_0,
                                            const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                            cudaStream_t stream);
template void TensorScatterUpdate<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
                                                       unsigned char *output, const size_t &block_size,
                                                       const size_t &input_size, const size_t &output_size,
                                                       const size_t &indices_dim_0, const size_t &indices_dim_1,
                                                       int *indices_stride, int *work_shape, cudaStream_t stream);
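The per-thread arithmetic in TensorScatterUpdateKernel can be mirrored in plain NumPy as a single-threaded loop. This is a rough sketch only (the scatter_update_flat helper and its arguments are illustrative), assuming the block_size, strides, and work shape that GetSize() and Init() would produce:

import numpy as np

def scatter_update_flat(x, indices, updates, strides, block_size, work_shape):
    # Illustrative mirror of the CUDA kernel: one loop iteration per update element.
    flat = x.reshape(-1).copy()
    upd = updates.reshape(-1)
    idx = indices.reshape(-1, indices.shape[-1])
    for read_index in range(upd.size):
        i, j = divmod(read_index, block_size)
        # Flat write offset: index components times their strides, plus the position
        # inside the trailing block; out-of-range indices are skipped.
        out_bound = bool(np.any(idx[i] >= np.asarray(work_shape)[:idx.shape[1]]))
        write_index = int(np.dot(idx[i], strides)) + j
        out_bound |= write_index >= flat.size
        if not out_bound:
            flat[write_index] = upd[read_index]
    return flat.reshape(x.shape)

x = np.arange(21, dtype=np.float32).reshape(3, 7)
indices = np.array([[0, 1], [2, 1]], dtype=np.int32)
updates = np.array([3.2, -1.0], dtype=np.float32)
print(scatter_update_flat(x, indices, updates, strides=[7, 1], block_size=1, work_shape=[3, 7]))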
@@ -0,0 +1,26 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH

#include "runtime/device/gpu/cuda_common.h"

template <typename T, typename S>
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size,
                         const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                         const size_t &indices_dim_1, S *indices_stride, S *work_shape, cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
@@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.scatter = P.TensorScatterUpdate()

    def construct(self, x, indices, update):
        return self.scatter(x, indices, update)


def scatter_net(x, indices, update):
    scatter = Net()
    return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
def test_scatter():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    # 2-D input with 2-D indices: each index row addresses a single element.
    arr_input = np.arange(21).reshape(3, 7).astype(np.float32)
    arr_indices = np.array([[0, 1], [1, 1], [0, 5], [0, 2], [2, 1]]).astype(np.int32)
    arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(np.float32)
    out = scatter_net(arr_input, arr_indices, arr_update)
    expected = np.array([[0, 3.2, -2.2, 3, 4, 5.3, 6],
                         [7, 1.1, 9, 10, 11, 12, 13],
                         [14, -1, 16, 17, 18, 19, 20]]).astype(np.float32)
    np.testing.assert_allclose(out, expected, rtol=1e-6)

    # 3-D input with full-depth index rows.
    arr_input = np.arange(24).reshape(4, 2, 3).astype(np.float32)
    arr_indices = np.array([[0, 0, 0], [1, 1, 1], [0, 1, 1], [3, 0, 1]]).astype(np.int32)
    arr_update = np.array([-1, -2, -3, -4]).astype(np.float32)
    out = scatter_net(arr_input, arr_indices, arr_update)
    expected = np.array([[[-1, 1, 2],
                          [3, -3, 5]],
                         [[6, 7, 8],
                          [9, -2, 11]],
                         [[12, 13, 14],
                          [15, 16, 17]],
                         [[18, -4, 20],
                          [21, 22, 23]]]).astype(np.float32)
    np.testing.assert_allclose(out, expected, rtol=1e-6)

    # 3-D indices: the leading indices dimensions are batched.
    arr_input = np.arange(25).reshape(5, 5).astype(np.float32)
    arr_indices = np.array([[[0, 0],
                             [1, 1],
                             [2, 2],
                             [3, 3],
                             [4, 4]],
                            [[0, 4],
                             [1, 3],
                             [2, 2],
                             [3, 1],
                             [4, 0]]]).astype(np.int32)
    arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float32)
    out = scatter_net(arr_input, arr_indices, arr_update)
    expected = np.array([[11, 1, 2, 3, 66],
                         [5, 22, 7, 77, 9],
                         [10, 11, 33, 13, 14],
                         [15, 99, 17, 44, 19],
                         [100, 21, 22, 23, 55]]).astype(np.float32)
    np.testing.assert_allclose(out, expected, rtol=1e-6)