From: @xcnick Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -18,48 +18,82 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| if (shape.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size(); | |||
| namespace { | |||
| size_t get_element_num(const std::vector<size_t> &shape) { | |||
| size_t size = 1; | |||
| for (size_t i = 0; i < shape.size(); i++) { | |||
| size *= shape[i]; | |||
| } | |||
| batch_size_ = shape[0]; | |||
| class_num_ = shape[1]; | |||
| return size; | |||
| } | |||
| int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS); | |||
| if (axis != -1 && axis != 1) { | |||
| MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis; | |||
| template <typename T> | |||
| bool check_validation(const std::vector<size_t> &shape, const size_t num_before_axis, const size_t num_after_axis, | |||
| const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Wrong number of inputs or outputs!"; | |||
| return false; | |||
| } | |||
| size_t data_size = sizeof(T); | |||
| size_t input_size = get_element_num(shape) * data_size; | |||
| size_t output_num = num_before_axis * num_after_axis; | |||
| size_t output_size = output_num * sizeof(int); | |||
| if (inputs[0]->size != input_size || outputs[0]->size != output_size) { | |||
| MS_LOG(EXCEPTION) << "invalid input or output data size!"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| bool ArgmaxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspaces*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.empty() || outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "input or output empty!"; | |||
| template <typename T> | |||
| void ArgmaxCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| size_t shape_len = shape_.size(); | |||
| int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS); | |||
| axis += shape_len; | |||
| if (axis < 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid axis:" << axis << ", should in range [-1, " << shape_len - 1 << "]"; | |||
| } | |||
| axis = axis % static_cast<int64_t>(shape_len); | |||
| num_before_axis_ = 1; | |||
| num_after_axis_ = 1; | |||
| for (size_t i = 0; i < shape_len; i++) { | |||
| if (static_cast<int64_t>(i) < axis) { | |||
| num_before_axis_ *= shape_[i]; | |||
| } else if (static_cast<int64_t>(i) > axis) { | |||
| num_after_axis_ *= shape_[i]; | |||
| } | |||
| } | |||
| dim_axis_ = shape_[axis]; | |||
| } | |||
| size_t batch_float_size = batch_size_ * sizeof(float); | |||
| size_t batch_class_float_size = class_num_ * batch_float_size; | |||
| if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { | |||
| MS_LOG(EXCEPTION) << "invalid input or output data size!"; | |||
| template <typename T> | |||
| bool ArgmaxCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspaces*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) { | |||
| return false; | |||
| } | |||
| auto input = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto output = reinterpret_cast<int *>(outputs[0]->addr); | |||
| size_t row_start = 0; | |||
| for (size_t i = 0; i < batch_size_; ++i) { | |||
| size_t max_index = 0; | |||
| float max_value = input[row_start]; | |||
| for (size_t j = 1; j < class_num_; ++j) { | |||
| size_t index = row_start + j; | |||
| if (input[index] > max_value) { | |||
| max_value = input[index]; | |||
| max_index = j; | |||
| auto input = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto output = reinterpret_cast<int32_t *>(outputs[0]->addr); | |||
| for (size_t i = 0; i < num_before_axis_; i++) { | |||
| size_t src_index_i = i * dim_axis_ * num_after_axis_; | |||
| for (size_t j = 0; j < num_after_axis_; j++) { | |||
| std::vector<float> array_axis; | |||
| size_t src_index_j = src_index_i + j; | |||
| for (size_t k = 0; k < dim_axis_; k++) { | |||
| size_t src_index_k = k * num_after_axis_ + src_index_j; | |||
| array_axis.push_back(static_cast<float>(input[src_index_k])); | |||
| } | |||
| auto max_ops = std::max_element(array_axis.begin(), array_axis.end()); | |||
| auto max_index = static_cast<int32_t>(std::distance(array_axis.begin(), max_ops)); | |||
| auto dst_index = i * num_after_axis_ + j; | |||
| output[dst_index] = max_index; | |||
| } | |||
| output[i] = SizeToInt(max_index); | |||
| row_start += class_num_; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class ArgmaxCPUKernel : public CPUKernel { | |||
| public: | |||
| ArgmaxCPUKernel() = default; | |||
| @@ -33,12 +34,16 @@ class ArgmaxCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| size_t class_num_{0}; | |||
| size_t batch_size_{0}; | |||
| std::vector<size_t> shape_; | |||
| size_t num_before_axis_; | |||
| size_t num_after_axis_; | |||
| size_t dim_axis_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxCPUKernel); | |||
| MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxCPUKernel, float16); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -18,9 +18,9 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_TWO(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxGpuKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| ArgmaxGpuKernel, half, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -23,11 +23,10 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| #define ARGMAX_MAX_DIMENSION 2 | |||
| template <typename T> | |||
| template <typename T, typename S> | |||
| class ArgmaxGpuKernel : public GpuKernel { | |||
| public: | |||
| ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {} | |||
| ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), bound_(0), outer_size_(0), inner_size_(0) {} | |||
| ~ArgmaxGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -37,47 +36,38 @@ class ArgmaxGpuKernel : public GpuKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| int *output = GetDeviceAddress<int>(outputs, 0); | |||
| CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| S *output = GetDeviceAddress<S>(outputs, 0); | |||
| CalArgmax(input, bound_, outer_size_, inner_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input."; | |||
| return false; | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| int64_t dims = shape.size(); | |||
| int64_t axis = GetAttr<int64_t>(kernel_node, "axis"); | |||
| if (axis < 0) { | |||
| axis += dims; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output."; | |||
| return false; | |||
| input_size_ = sizeof(T); | |||
| for (auto x : shape) { | |||
| input_size_ *= x; | |||
| } | |||
| auto output_type = GetValue<TypePtr>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type")); | |||
| if (output_type->type_id() != TypeId::kNumberTypeInt32) { | |||
| MS_LOG(EXCEPTION) << "Argmax only supports int32 output type."; | |||
| output_size_ = sizeof(S); | |||
| for (auto x : output_shape) { | |||
| output_size_ *= x; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() > ARGMAX_MAX_DIMENSION) { | |||
| MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but Argmax supports max " << ARGMAX_MAX_DIMENSION | |||
| << "-D inputs."; | |||
| bound_ = static_cast<S>(shape[axis]); | |||
| if (shape[axis] != static_cast<size_t>(bound_)) { | |||
| MS_LOG(EXCEPTION) << "Bound's shape is larger than index type and overflows when casting."; | |||
| } | |||
| axis_ = GetAttr<int64_t>(kernel_node, "axis"); | |||
| if (axis_ < 0) { | |||
| axis_ += static_cast<int64_t>(input_shape.size()); | |||
| outer_size_ = 1; | |||
| for (int64_t i = axis - 1; i >= 0; i--) { | |||
| outer_size_ *= shape[i]; | |||
| } | |||
| if (input_shape.size() == 1) { | |||
| batch_size_ = 0; | |||
| channel_size_ = input_shape[0]; | |||
| input_size_ = sizeof(T) * channel_size_; | |||
| output_size_ = sizeof(int); | |||
| } else { | |||
| batch_size_ = input_shape[0]; | |||
| channel_size_ = input_shape[1]; | |||
| input_size_ = sizeof(T) * batch_size_ * channel_size_; | |||
| output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_; | |||
| inner_size_ = 1; | |||
| for (int64_t i = axis + 1; i < dims; i++) { | |||
| inner_size_ *= shape[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| @@ -96,9 +86,9 @@ class ArgmaxGpuKernel : public GpuKernel { | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t batch_size_; | |||
| size_t channel_size_; | |||
| int64_t axis_; | |||
| S bound_; | |||
| size_t outer_size_; | |||
| size_t inner_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -17,72 +17,36 @@ | |||
| #include "argmax_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| #include "include/cuda_fp16.h" | |||
| template <typename T> | |||
| __global__ void Argmax1D(const T *input, const int channel_size, int *output) { | |||
| int max_index = 0; | |||
| T max = input[0]; | |||
| for (int pos = 1; pos < channel_size; pos++) { | |||
| if (max < input[pos]) { | |||
| max = input[pos]; | |||
| max_index = pos; | |||
| template <typename T, typename S> | |||
| __global__ void Argmax(const T *input, const S bound, const size_t outer_size, | |||
| const size_t inner_size, S *output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size * inner_size; | |||
| pos += gridDim.x * blockDim.x) { | |||
| size_t x = pos / inner_size % outer_size; | |||
| size_t y = pos % inner_size; | |||
| S idx = 0; | |||
| size_t input_offset = x * bound * inner_size + 0 * inner_size + y; | |||
| T max_data = input[input_offset]; | |||
| for (S i = 1; i < bound; i++) { | |||
| input_offset = x * bound * inner_size + i * inner_size + y; | |||
| auto input_data = input[input_offset]; | |||
| idx = input_data > max_data ? i : idx; | |||
| max_data = input_data > max_data ? input_data : max_data; | |||
| } | |||
| output[pos] = idx; | |||
| } | |||
| output[0] = max_index; | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ArgmaxDefault2D(const T *input, const int batch_size, const int channel_size, int *output) { | |||
| int pos; | |||
| int max_index; | |||
| T max; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) { | |||
| max = input[i * channel_size]; | |||
| max_index = 0; | |||
| for (int j = 1; j < channel_size; j++) { | |||
| pos = i * channel_size + j; | |||
| if (max < input[pos]) { | |||
| max = input[pos]; | |||
| max_index = j; | |||
| } | |||
| } | |||
| output[i] = max_index; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ArgmaxAxis2D(const T *input, const int batch_size, const int channel_size, int *output) { | |||
| int pos; | |||
| int max_index; | |||
| T max; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { | |||
| max = input[i]; | |||
| max_index = 0; | |||
| for (int j = 1; j < batch_size; j++) { | |||
| pos = j * channel_size + i; | |||
| if (max < input[pos]) { | |||
| max = input[pos]; | |||
| max_index = j; | |||
| } | |||
| } | |||
| output[i] = max_index; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output, | |||
| cudaStream_t cuda_stream) { | |||
| if (batch_size == 0) { | |||
| Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output); | |||
| } else if (axis == 1) { | |||
| ArgmaxDefault2D<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output); | |||
| } else { | |||
| ArgmaxAxis2D<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output); | |||
| } | |||
| template <typename T, typename S> | |||
| void CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, | |||
| S *output, cudaStream_t cuda_stream) { | |||
| Argmax<<<GET_BLOCKS(outer_size), GET_THREADS, 0, cuda_stream>>>(input, bound, outer_size, inner_size, | |||
| output); | |||
| return; | |||
| } | |||
| template void CalArgmax<float>(const float *input, const int batch_size, const int channel_size, const int64_t axis, | |||
| int *output, cudaStream_t cuda_stream); | |||
| template void CalArgmax<half>(const half *input, const int batch_size, const int channel_size, const int64_t axis, | |||
| int *output, cudaStream_t cuda_stream); | |||
| template void CalArgmax<float, int>(const float *input, const int bound, const size_t outer_size, | |||
| const size_t inner_size, int *output, cudaStream_t cuda_stream); | |||
| template void CalArgmax<half, int>(const half *input, const int bound, const size_t outer_size, | |||
| const size_t inner_size, int *output, cudaStream_t cuda_stream); | |||
| @@ -16,8 +16,8 @@ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_ | |||
| template <typename T> | |||
| void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output, | |||
| template <typename T, typename S> | |||
| void CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_ | |||
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import random | |||
| from functools import reduce | |||
| import numpy as np | |||
| import pytest | |||
| @@ -20,33 +22,59 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import operations as P | |||
| import mindspore.ops as ops | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| class NetArgmax(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, axis=0): | |||
| super(NetArgmax, self).__init__() | |||
| self.argmax = P.Argmax(output_type=mstype.int32) | |||
| x = Tensor(np.array([[1., 20., 5.], | |||
| [67., 8., 9.], | |||
| [130., 24., 15.]]).astype(np.float32)) | |||
| self.x = Parameter(initializer(x, x.shape), name='x') | |||
| self.argmax = ops.Argmax(axis=axis, output_type=mstype.int32) | |||
| def construct(self): | |||
| return self.argmax(self.x) | |||
| def construct(self, x): | |||
| return self.argmax(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_argmax(): | |||
| Argmax = NetArgmax() | |||
| output = Argmax() | |||
| print("================================") | |||
| def test_argmax_1d(): | |||
| x = Tensor(np.array([1., 20., 5.]).astype(np.float32)) | |||
| Argmax = NetArgmax(axis=0) | |||
| output = Argmax(x) | |||
| expect = np.array([1]).astype(np.float32) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_argmax_2d(): | |||
| x = Tensor(np.array([[1., 20., 5.], | |||
| [67., 8., 9.], | |||
| [130., 24., 15.]]).astype(np.float32)) | |||
| Argmax_axis_0 = NetArgmax(axis=0) | |||
| output = Argmax_axis_0(x) | |||
| expect = np.array([2, 2, 2]).astype(np.float32) | |||
| assert (output.asnumpy() == expect).all() | |||
| Argmax_axis_1 = NetArgmax(axis=1) | |||
| output = Argmax_axis_1(x) | |||
| expect = np.array([1, 0, 0]).astype(np.float32) | |||
| print(output) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_argmax_high_dims(): | |||
| for dim in range(3, 10): | |||
| shape = np.random.randint(1, 10, size=dim) | |||
| x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32) | |||
| x = x.reshape(shape) | |||
| rnd_axis = random.randint(-dim + 1, dim - 1) | |||
| Argmax = NetArgmax(axis=rnd_axis) | |||
| ms_output = Argmax(Tensor(x)) | |||
| np_output = np.argmax(x, axis=rnd_axis) | |||
| assert (ms_output.asnumpy() == np_output).all() | |||
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import random | |||
| from functools import reduce | |||
| import numpy as np | |||
| import pytest | |||
| @@ -20,43 +22,67 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| import mindspore.ops as ops | |||
| class NetArgmax(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, axis=0): | |||
| super(NetArgmax, self).__init__() | |||
| axis1 = 0 | |||
| axis2 = -1 | |||
| self.argmax1 = P.Argmax(axis1, output_type=mstype.int32) | |||
| self.argmax2 = P.Argmax(axis2, output_type=mstype.int32) | |||
| self.argmax3 = P.Argmax(output_type=mstype.int32) | |||
| self.argmax = ops.Argmax(axis, output_type=mstype.int32) | |||
| def construct(self, x): | |||
| return (self.argmax1(x), self.argmax2(x), self.argmax3(x)) | |||
| return self.argmax(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_argmax(): | |||
| def test_argmax_1d(): | |||
| for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: | |||
| context.set_context(mode=mode, device_target="GPU") | |||
| x = Tensor(np.array([1., 20., 5.]).astype(np.float32)) | |||
| Argmax = NetArgmax(axis=0) | |||
| output = Argmax(x) | |||
| expect = np.array([1]).astype(np.float32) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_argmax_2d(): | |||
| for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: | |||
| context.set_context(mode=mode, device_target="GPU") | |||
| x = Tensor(np.array([[1., 20., 5.], | |||
| [67., 8., 9.], | |||
| [130., 24., 15.], | |||
| [0.3, -0.4, -15.]]).astype(np.float32)) | |||
| expect1 = np.array([2, 2, 2]).astype(np.int32) | |||
| expect2 = np.array([1, 0, 0, 0]).astype(np.int32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| argmax = NetArgmax() | |||
| output = argmax(x) | |||
| assert (output[0].asnumpy() == expect1).all() | |||
| assert (output[1].asnumpy() == expect2).all() | |||
| assert (output[2].asnumpy() == expect2).all() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| argmax1 = NetArgmax() | |||
| output1 = argmax1(x) | |||
| assert (output1[0].asnumpy() == expect1).all() | |||
| assert (output1[1].asnumpy() == expect2).all() | |||
| assert (output1[2].asnumpy() == expect2).all() | |||
| Argmax_axis_0 = NetArgmax(axis=0) | |||
| output = Argmax_axis_0(x) | |||
| expect = np.array([2, 2, 2]).astype(np.int32) | |||
| assert (output.asnumpy() == expect).all() | |||
| Argmax_axis_1 = NetArgmax(axis=1) | |||
| output = Argmax_axis_1(x) | |||
| expect = np.array([1, 0, 0, 0]).astype(np.int32) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_argmax_high_dims(): | |||
| for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: | |||
| context.set_context(mode=mode, device_target="GPU") | |||
| for dim in range(3, 10): | |||
| shape = np.random.randint(1, 10, size=dim) | |||
| x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32) | |||
| x = x.reshape(shape) | |||
| rnd_axis = random.randint(-dim + 1, dim - 1) | |||
| Argmax = NetArgmax(axis=rnd_axis) | |||
| ms_output = Argmax(Tensor(x)) | |||
| np_output = np.argmax(x, axis=rnd_axis) | |||
| assert (ms_output.asnumpy() == np_output).all() | |||