From: @tom__chen
Reviewed-by: @liangchenghui, @robingrosman
Signed-off-by: @liangchenghui
Tag: v1.2.0-rc1
@@ -0,0 +1,106 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"

__global__ void InitErrorCode(IndexAddErrorCode *error_code) {
  *error_code = IndexAddErrorCode::kOk;
}

__global__ void ValidateIndexValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
                                    IndexAddErrorCode *error_code) {
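  // Every thread that finds an invalid index writes the same kIndexOutOfRange value, so the unsynchronized
  // writes to *error_code are benign: the kernel reports that some index is out of range, not which one.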
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_axis_size; pos += blockDim.x * gridDim.x) {
    const int idx_value = index[pos];
    if (idx_value < 0 || idx_value >= dst_axis_size) {
      *error_code = IndexAddErrorCode::kIndexOutOfRange;
      return;
    }
  }
  return;
}

template <typename T>
__global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
                               const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
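  // src is laid out as a contiguous (outer_size, src_axis_size, inner_size) block, so a flat position
  // decomposes into (outer, axis, inner) coordinates. The axis coordinate is remapped through index[],
  // and the same outer/inner coordinates locate the target element in dst, whose axis dimension holds
  // dst_axis_size elements.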
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_size; pos += blockDim.x * gridDim.x) {
    const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
    const size_t src_outer_idx = pos / (src_axis_size * inner_size);
    const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
    const size_t dst_inner_idx = pos % inner_size;
    const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
    MsAtomicAdd(&dst[dst_idx], src[pos]);
  }
  return;
}

template <typename T>
__global__ void IndexAdd(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
                         const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
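  // Same index mapping as IndexAddAtomic, but the accumulation is a plain += instead of MsAtomicAdd.
  // This path is only safe when no two threads can update the same dst element concurrently, e.g. when
  // the index values are unique; the caller selects it via use_lock == false.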
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_size; pos += blockDim.x * gridDim.x) {
    const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
    const size_t src_outer_idx = pos / (src_axis_size * inner_size);
    const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
    const size_t dst_inner_idx = pos % inner_size;
    const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
    dst[dst_idx] += src[pos];
  }
  return;
}

void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
                                 IndexAddErrorCode *error_code, cudaStream_t cuda_stream) {
  InitErrorCode<<<1, 1, 0, cuda_stream>>>(error_code);
  ValidateIndexValues<<<GET_BLOCKS(src_axis_size), GET_THREADS, 0, cuda_stream>>>(index, src_axis_size, dst_axis_size,
                                                                                  error_code);
}
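
// Example of the size parameters taken by CalIndexAdd (illustrative shapes): for dst of shape (2, 5, 3),
// src of shape (2, 2, 3) and axis = 1, the caller passes outer_size = 2, src_axis_size = 2,
// dst_axis_size = 5 and inner_size = 3, and index holds 2 values in the range [0, 5).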
template <typename T>
void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
                 const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream) {
  size_t src_size = outer_size * src_axis_size * inner_size;
  if (use_lock) {
    IndexAddAtomic<<<GET_BLOCKS(src_size), GET_THREADS, 0, cuda_stream>>>(dst, index, src, src_size, outer_size,
                                                                          src_axis_size, dst_axis_size, inner_size);
  } else {
    IndexAdd<<<GET_BLOCKS(src_size), GET_THREADS, 0, cuda_stream>>>(dst, index, src, src_size, outer_size,
                                                                    src_axis_size, dst_axis_size, inner_size);
  }
  return;
}

template void CalIndexAdd<double>(double *dst, const int *index, const double *src, const size_t outer_size,
                                  const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                                  const bool use_lock, cudaStream_t cuda_stream);
template void CalIndexAdd<float>(float *dst, const int *index, const float *src, const size_t outer_size,
                                 const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                                 const bool use_lock, cudaStream_t cuda_stream);
template void CalIndexAdd<half>(half *dst, const int *index, const half *src, const size_t outer_size,
                                const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                                const bool use_lock, cudaStream_t cuda_stream);
template void CalIndexAdd<int>(int *dst, const int *index, const int *src, const size_t outer_size,
                               const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                               const bool use_lock, cudaStream_t cuda_stream);
template void CalIndexAdd<int16_t>(int16_t *dst, const int *index, const int16_t *src, const size_t outer_size,
                                   const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                                   const bool use_lock, cudaStream_t cuda_stream);
template void CalIndexAdd<int8_t>(int8_t *dst, const int *index, const int8_t *src, const size_t outer_size,
                                  const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                                  const bool use_lock, cudaStream_t cuda_stream);
template void CalIndexAdd<uint8_t>(uint8_t *dst, const int *index, const uint8_t *src, const size_t outer_size,
                                   const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size,
                                   const bool use_lock, cudaStream_t cuda_stream);
@@ -0,0 +1,31 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_

enum class IndexAddErrorCode {
  kOk = 0,
  kIndexOutOfRange
};

void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
                                 IndexAddErrorCode *error_code, cudaStream_t cuda_stream);

template <typename T>
void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
                 const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
@@ -0,0 +1,71 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddOutputAttr(kNumberTypeFloat64),
                      IndexAddGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      IndexAddGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      IndexAddGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      IndexAddGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt16)
                        .AddOutputAttr(kNumberTypeInt16),
                      IndexAddGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt8)
                        .AddOutputAttr(kNumberTypeInt8),
                      IndexAddGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt8)
                        .AddOutputAttr(kNumberTypeUInt8),
                      IndexAddGpuKernel, uint8_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,155 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class IndexAddGpuKernel : public GpuKernel {
 public:
  IndexAddGpuKernel()
      : dst_size_(0),
        index_size_(0),
        src_size_(0),
        output_size_(0),
        outer_size_(0),
        src_axis_size_(0),
        dst_axis_size_(0),
        inner_size_(0),
        use_lock_(true),
        check_index_bound_(true) {}
  ~IndexAddGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
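    // Launch flow: (1) optionally validate every index value on the device and raise if any is out of
    // range, (2) accumulate src into dst in place via CalIndexAdd, (3) copy the updated dst buffer to
    // the kernel output.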
    T *dst = GetDeviceAddress<T>(inputs, 0);
    int *index = GetDeviceAddress<int>(inputs, 1);
    T *src = GetDeviceAddress<T>(inputs, 2);
    T *dst_out = GetDeviceAddress<T>(outputs, 0);
    if (check_index_bound_) {
      IndexAddErrorCode *error_code_addr = GetDeviceAddress<IndexAddErrorCode>(workspace, 0);
      IndexAddErrorCode error_code = IndexAddErrorCode::kOk;
      ValidateIndexAddInputValues(index, src_axis_size_, dst_axis_size_, error_code_addr,
                                  reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
                                cudaMemcpyAsync(&error_code, error_code_addr, sizeof(IndexAddErrorCode),
                                                cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "Failed to copy error code to host.");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
      LogExceptionIfNotOk(error_code);
    }
    CalIndexAdd(dst, index, src, outer_size_, src_axis_size_, dst_axis_size_, inner_size_, use_lock_,
                reinterpret_cast<cudaStream_t>(stream_ptr));
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(&dst_out[0], &dst[0], dst_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync output failed");
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 3) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but index add needs 3 inputs.";
      return false;
    }
    std::vector<size_t> dst_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    std::vector<size_t> index_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    std::vector<size_t> src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    int64_t src_rank = src_shape.size();
    int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
    if (axis < 0) {
      axis += src_rank;
    }
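    // With the axis normalized to be non-negative, outer_size_ is the product of the dimensions before
    // the axis and inner_size_ the product of the dimensions after it, so every tensor is treated as a
    // contiguous (outer, axis, inner) block by the CUDA kernels.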
    outer_size_ = 1;
    for (int64_t i = axis - 1; i >= 0; i--) {
      outer_size_ *= src_shape[i];
    }
    inner_size_ = 1;
    for (int64_t i = axis + 1; i < src_rank; i++) {
      inner_size_ *= src_shape[i];
    }
    src_axis_size_ = src_shape[axis];
    dst_axis_size_ = dst_shape[axis];

    dst_size_ = sizeof(T);
    for (auto x : dst_shape) {
      dst_size_ *= x;
    }
    index_size_ = sizeof(int);
    for (auto x : index_shape) {
      index_size_ *= x;
    }
    src_size_ = sizeof(T);
    for (auto x : src_shape) {
      src_size_ *= x;
    }
    output_size_ = dst_size_;
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(dst_size_);
    input_size_list_.push_back(index_size_);
    input_size_list_.push_back(src_size_);
    output_size_list_.push_back(output_size_);
    workspace_size_list_.push_back(sizeof(IndexAddErrorCode));
  }

 private:
  void LogExceptionIfNotOk(IndexAddErrorCode error_code) {
    switch (error_code) {
      case IndexAddErrorCode::kOk:
        return;
      case IndexAddErrorCode::kIndexOutOfRange:
        MS_LOG(EXCEPTION) << "gpu IndexAdd op error: values of the index tensor are out of range";
        break;
      default:
        MS_LOG(EXCEPTION) << "gpu IndexAdd op unknown error";
    }
  }
  size_t dst_size_;
  size_t index_size_;
  size_t src_size_;
  size_t output_size_;
  size_t outer_size_;
  size_t src_axis_size_;
  size_t dst_axis_size_;
  size_t inner_size_;
  bool use_lock_;
  bool check_index_bound_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
@@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                        Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
-                       MatrixInverse)
+                       MatrixInverse, IndexAdd)
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -418,6 +418,7 @@ __all__ = [
     "SparseToDense",
     "MatrixInverse",
     "Range",
+    "IndexAdd",
 ]
 __all__.sort()
@@ -4199,3 +4199,68 @@ class MatrixInverse(PrimitiveWithInfer):
         validator.check_int(len(x_shape), 2, Rel.GE, self.name, None)
         validator.check_equal_int(x_shape[-1], x_shape[-2], self.name, None)
         return x_shape


class IndexAdd(PrimitiveWithInfer):
    """
    Adds tensor y to the specified axis and indices of tensor x.

    Args:
        axis (int): The dimension along which to index.

    Inputs:
        - **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
          int8, uint8.
        - **indices** (Tensor) - The indices of `input_x` on the `axis`th dimension to add to, with data type int32.
          The `indices` must be 1-D with the same size as the `axis`th dimension of `input_y`. The values of
          `indices` should be in the range of 0 to the size of the `axis`th dimension of `input_x`.
        - **input_y** (Tensor) - The input tensor with the values to add. Must have the same data type as `input_x`.
          The shape must be the same as `input_x` except for the `axis`th dimension.

    Outputs:
        Tensor, has the same shape and dtype as `input_x`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [6, 7, 8]]), mindspore.float32)
        >>> input_y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
        >>> indices = Tensor(np.array([0, 2]), mindspore.int32)
        >>> index_add = ops.IndexAdd(axis=1)
        >>> output = index_add(input_x, indices, input_y)
        >>> print(output)
        [[ 1.5  2.   4. ]
         [ 5.   5.   7.5]
         [ 8.   7.  10.5]]
    """

    @prim_attr_register
    def __init__(self, axis, use_lock=True, check_index_bound=True):
        """Initialize IndexAdd"""
        self.init_prim_io_names(inputs=['input_x', 'indices', 'input_y'], outputs=['output'])
        self.axis = axis
        validator.check_value_type('axis', axis, [int], self.name)

    def infer_dtype(self, x_dtype, idx_type, y_dtype):
        args = {'x': x_dtype, 'y': y_dtype}
        valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.int32, mstype.int16, mstype.int8,
                      mstype.uint8]
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        valid_idx_type = [mstype.int32]
        validator.check_tensor_dtype_valid("idx_type", idx_type, valid_idx_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, idx_shape, y_shape):
        validator.check("x rank", len(x_shape), "y rank", len(y_shape), Rel.EQ, self.name)
        validator.check("size of indices", idx_shape[0], "dimension of y[axis]", y_shape[self.axis],
                        Rel.EQ, self.name)
        x_rank = len(x_shape)
        validator.check_int_range(self.axis, -x_rank, x_rank - 1, Rel.INC_BOTH, 'axis', self.name)
        axis = self.axis if self.axis >= 0 else x_rank + self.axis
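        # y must match x on every dimension except the axis-th one, where y may be smaller:
        # its size along the axis equals the number of indices being added to.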
        for dim in range(x_rank):
            if dim == axis:
                validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.GE, self.name)
            else:
                validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
        return x_shape
@@ -0,0 +1,259 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class NetIndexAdd(nn.Cell):
    def __init__(self, axis):
        super(NetIndexAdd, self).__init__()
        self.index_add = P.IndexAdd(axis)

    def construct(self, x, idx, y):
        z = self.index_add(x, idx, y)
        return z


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add():
    x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).astype(np.float32)
    y0 = np.ones((1, 3, 4, 4), dtype=np.float32)
    idx0 = np.array([1]).astype(np.int32)
    axis0 = 0
    expect = np.copy(x)
    expect[idx0, :, :, :] = expect[idx0, :, :, :] + y0
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis0)
    output = net(Tensor(x), Tensor(idx0), Tensor(y0))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis0)
    output = net(Tensor(x), Tensor(idx0), Tensor(y0))
    assert (output.asnumpy() == expect).all()

    y1 = np.ndarray((2, 2, 4, 4)).astype(np.float32)
    y1.fill(0.1)
    idx1 = np.array([0, 2]).astype(np.int32)
    axis1 = 1
    expect = np.copy(x)
    expect[:, idx1, :, :] = expect[:, idx1, :, :] + y1
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    net = NetIndexAdd(axis1)
    output = net(Tensor(x), Tensor(idx1), Tensor(y1))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis1)
    output = net(Tensor(x), Tensor(idx1), Tensor(y1))
    assert (output.asnumpy() == expect).all()

    y2 = np.ones((2, 3, 2, 4)).astype(np.float32)
    y2.fill(5.5)
    idx2 = np.array([1, 3]).astype(np.int32)
    axis2 = 2
    expect = np.copy(x)
    expect[:, :, idx2, :] = expect[:, :, idx2, :] + y2
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    net = NetIndexAdd(axis2)
    output = net(Tensor(x), Tensor(idx2), Tensor(y2))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis2)
    output = net(Tensor(x), Tensor(idx2), Tensor(y2))
    assert (output.asnumpy() == expect).all()

    y3 = np.ones((2, 3, 4, 3)).astype(np.float32)
    y3.fill(1000.00)
    idx3 = np.array([0, 2, 3]).astype(np.int32)
    axis3 = 3
    expect = np.copy(x)
    expect[:, :, :, idx3] = expect[:, :, :, idx3] + y3
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    net = NetIndexAdd(axis3)
    output = net(Tensor(x), Tensor(idx3), Tensor(y3))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis3)
    output = net(Tensor(x), Tensor(idx3), Tensor(y3))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_float16():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float16)
    y = np.ones((2, 2, 4), dtype=np.float16)
    idx = np.array([0, 2]).astype(np.int32)
    axis = 1
    expect = np.copy(x)
    expect[:, idx, :] = expect[:, idx, :] + y
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_int32():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.int32)
    y = np.ones((2, 2, 4), dtype=np.int32)
    idx = np.array([0, 2]).astype(np.int32)
    axis = 1
    expect = np.copy(x)
    expect[:, idx, :] = expect[:, idx, :] + y
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_int8():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.int8)
    y = np.ones((2, 2, 4), dtype=np.int8)
    idx = np.array([0, 2]).astype(np.int32)
    axis = 1
    expect = np.copy(x)
    expect[:, idx, :] = expect[:, idx, :] + y
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_uint8():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.uint8)
    y = np.ones((2, 2, 4), dtype=np.uint8)
    idx = np.array([0, 2]).astype(np.int32)
    axis = 1
    expect = np.copy(x)
    expect[:, idx, :] = expect[:, idx, :] + y
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_float64():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float64)
    y = np.ones((2, 2, 4), dtype=np.float64)
    idx = np.array([0, 2]).astype(np.int32)
    axis = 1
    expect = np.copy(x)
    expect[:, idx, :] = expect[:, idx, :] + y
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_int16():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.int16)
    y = np.ones((2, 2, 4), dtype=np.int16)
    idx = np.array([0, 2]).astype(np.int32)
    axis = 1
    expect = np.copy(x)
    expect[:, idx, :] = expect[:, idx, :] + y
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = NetIndexAdd(axis)
    output = net(Tensor(x), Tensor(idx), Tensor(y))
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_index_add_invalid_inputs():
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.uint8)
    y = np.ones((2, 2, 4), dtype=np.uint8)
    with pytest.raises(TypeError):
        # axis is not an int
        net = NetIndexAdd(1.0)

    with pytest.raises(TypeError):
        # x and y do not have the same dtype
        y_fp32 = np.ones((2, 2, 4), dtype=np.float32)
        idx = np.array([0, 1]).astype(np.int32)
        net = NetIndexAdd(1)
        _ = net(Tensor(x), Tensor(idx), Tensor(y_fp32))

    with pytest.raises(ValueError):
        # index size does not match the size of y[axis]
        idx = np.array([0]).astype(np.int32)
        net = NetIndexAdd(1)
        _ = net(Tensor(x), Tensor(idx), Tensor(y))

    with pytest.raises(ValueError):
        # x and y do not have the same rank
        y_bad_rank = np.ones((2, 2), dtype=np.uint8)
        idx = np.array([0, 1]).astype(np.int32)
        net = NetIndexAdd(1)
        _ = net(Tensor(x), Tensor(idx), Tensor(y_bad_rank))

    with pytest.raises(ValueError):
        # x and y do not have the same shape on dimensions other than the axis-th dimension
        y_bad_shape = np.ones((2, 2, 5), dtype=np.uint8)
        idx = np.array([0, 1]).astype(np.int32)
        net = NetIndexAdd(1)
        _ = net(Tensor(x), Tensor(idx), Tensor(y_bad_shape))

    with pytest.raises(RuntimeError) as info:
        # index values are not in the range of 0 to the size of x[axis]
        idx = np.array([5, 6]).astype(np.int32)
        net = NetIndexAdd(1)
        _ = net(Tensor(x), Tensor(idx), Tensor(y))
    assert "out of range" in str(info.value)