
fix cpu/gpu argmax op

tags/v1.2.0-rc1
xcnick committed 4 years ago · commit d65a5affba
8 changed files with 229 additions and 182 deletions
  1. +66 -32   mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc
  2. +9  -4    mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h
  3. +4  -4    mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc
  4. +28 -38   mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h
  5. +25 -61   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu
  6. +2  -2    mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh
  7. +44 -16   tests/st/ops/cpu/test_argmax_op.py
  8. +51 -25   tests/st/ops/gpu/test_argmax_op.py
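
Summary: the old Argmax kernels only handled 1-D/2-D float32 inputs with axis 1 or -1 (the CPU kernel rejected anything else outright; the GPU kernel special-cased 1-D and each 2-D axis). This commit generalizes both kernels to arbitrary rank and axis by decomposing the input around the reduction axis into (num_before_axis, dim_axis, num_after_axis), templates the CPU kernel over its input type (adding a float16 registration), and extends the ST tests to cross-check random 3-D through 9-D inputs against np.argmax. A standalone sketch of the shared indexing scheme follows the first diff below.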

+66 -32  mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc

@@ -18,48 +18,82 @@


 namespace mindspore {
 namespace kernel {
-void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (shape.size() != 2) {
-    MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size();
+namespace {
+size_t get_element_num(const std::vector<size_t> &shape) {
+  size_t size = 1;
+  for (size_t i = 0; i < shape.size(); i++) {
+    size *= shape[i];
   }
-  batch_size_ = shape[0];
-  class_num_ = shape[1];
+  return size;
+}
 
-  int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
-  if (axis != -1 && axis != 1) {
-    MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis;
+template <typename T>
+bool check_validation(const std::vector<size_t> &shape, const size_t num_before_axis, const size_t num_after_axis,
+                      const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "Wrong number of inputs or outputs!";
+    return false;
   }
+  size_t data_size = sizeof(T);
+  size_t input_size = get_element_num(shape) * data_size;
+  size_t output_num = num_before_axis * num_after_axis;
+  size_t output_size = output_num * sizeof(int);
+  if (inputs[0]->size != input_size || outputs[0]->size != output_size) {
+    MS_LOG(EXCEPTION) << "invalid input or output data size!";
+    return false;
+  }
+  return true;
 }
+}  // namespace
 
-bool ArgmaxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                             const std::vector<kernel::AddressPtr> & /*workspaces*/,
-                             const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.empty() || outputs.empty()) {
-    MS_LOG(EXCEPTION) << "input or output empty!";
+template <typename T>
+void ArgmaxCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  size_t shape_len = shape_.size();
+  int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
+  axis += shape_len;
+  if (axis < 0) {
+    MS_LOG(EXCEPTION) << "Invalid axis:" << axis << ", should in range [-1, " << shape_len - 1 << "]";
+  }
+  axis = axis % static_cast<int64_t>(shape_len);
+  num_before_axis_ = 1;
+  num_after_axis_ = 1;
+  for (size_t i = 0; i < shape_len; i++) {
+    if (static_cast<int64_t>(i) < axis) {
+      num_before_axis_ *= shape_[i];
+    } else if (static_cast<int64_t>(i) > axis) {
+      num_after_axis_ *= shape_[i];
+    }
   }
+  dim_axis_ = shape_[axis];
+}
 
-  size_t batch_float_size = batch_size_ * sizeof(float);
-  size_t batch_class_float_size = class_num_ * batch_float_size;
-  if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) {
-    MS_LOG(EXCEPTION) << "invalid input or output data size!";
+template <typename T>
+bool ArgmaxCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> & /*workspaces*/,
+                                const std::vector<kernel::AddressPtr> &outputs) {
+  if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
+    return false;
   }
-  auto input = reinterpret_cast<float *>(inputs[0]->addr);
-  auto output = reinterpret_cast<int *>(outputs[0]->addr);
-  size_t row_start = 0;
-  for (size_t i = 0; i < batch_size_; ++i) {
-    size_t max_index = 0;
-    float max_value = input[row_start];
-    for (size_t j = 1; j < class_num_; ++j) {
-      size_t index = row_start + j;
-      if (input[index] > max_value) {
-        max_value = input[index];
-        max_index = j;
+
+  auto input = reinterpret_cast<T *>(inputs[0]->addr);
+  auto output = reinterpret_cast<int32_t *>(outputs[0]->addr);
+
+  for (size_t i = 0; i < num_before_axis_; i++) {
+    size_t src_index_i = i * dim_axis_ * num_after_axis_;
+    for (size_t j = 0; j < num_after_axis_; j++) {
+      std::vector<float> array_axis;
+      size_t src_index_j = src_index_i + j;
+      for (size_t k = 0; k < dim_axis_; k++) {
+        size_t src_index_k = k * num_after_axis_ + src_index_j;
+        array_axis.push_back(static_cast<float>(input[src_index_k]));
       }
+      auto max_ops = std::max_element(array_axis.begin(), array_axis.end());
+      auto max_index = static_cast<int32_t>(std::distance(array_axis.begin(), max_ops));
+      auto dst_index = i * num_after_axis_ + j;
+      output[dst_index] = max_index;
     }
-    output[i] = SizeToInt(max_index);
-    row_start += class_num_;
   }
   return true;
 }
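
For reference, a minimal standalone sketch of the indexing scheme the new kernel uses: a row-major tensor is viewed as (num_before, dim_axis, num_after), element (i, k, j) sits at flat offset (i * dim_axis + k) * num_after + j, and the argmax scans k. This is an illustration under assumed names, not code from the MindSpore sources:

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone model of the kernel's (before, axis, after) decomposition.
std::vector<int> ArgmaxAlongAxis(const std::vector<float> &input,
                                 const std::vector<size_t> &shape, size_t axis) {
  size_t num_before = 1, num_after = 1;
  for (size_t i = 0; i < axis; i++) num_before *= shape[i];
  for (size_t i = axis + 1; i < shape.size(); i++) num_after *= shape[i];
  size_t dim_axis = shape[axis];

  std::vector<int> output(num_before * num_after);
  for (size_t i = 0; i < num_before; i++) {
    for (size_t j = 0; j < num_after; j++) {
      size_t base = i * dim_axis * num_after + j;  // offset of the k = 0 element
      int max_index = 0;
      float max_value = input[base];
      for (size_t k = 1; k < dim_axis; k++) {      // scan the reduction axis
        float v = input[base + k * num_after];
        if (v > max_value) {
          max_value = v;
          max_index = static_cast<int>(k);
        }
      }
      output[i * num_after + j] = max_index;       // same (i, j) layout as the kernel
    }
  }
  return output;
}

int main() {
  // 2x3 input; argmax along axis 1 should print "1 0".
  std::vector<float> x = {1.f, 20.f, 5.f, 67.f, 8.f, 9.f};
  for (int idx : ArgmaxAlongAxis(x, {2, 3}, 1)) std::cout << idx << ' ';
  std::cout << '\n';
  return 0;
}

Compiled with any C++11 compiler this prints "1 0", matching np.argmax along axis 1 of the 2x3 input.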


+9 -4  mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h

@@ -22,6 +22,7 @@


 namespace mindspore {
 namespace kernel {
+template <typename T>
 class ArgmaxCPUKernel : public CPUKernel {
  public:
   ArgmaxCPUKernel() = default;
@@ -33,12 +34,16 @@ class ArgmaxCPUKernel : public CPUKernel {
               const std::vector<AddressPtr> &outputs) override;
 
  private:
-  size_t class_num_{0};
-  size_t batch_size_{0};
+  std::vector<size_t> shape_;
+  size_t num_before_axis_;
+  size_t num_after_axis_;
+  size_t dim_axis_;
 };
 
-MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
-                  ArgmaxCPUKernel);
+MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                    ArgmaxCPUKernel, float);
+MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                    ArgmaxCPUKernel, float16);
 }  // namespace kernel
 }  // namespace mindspore




+4 -4  mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc

@@ -18,9 +18,9 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
-                      ArgmaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
-                      ArgmaxGpuKernel, half)
+MS_REG_GPU_KERNEL_TWO(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                      ArgmaxGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                      ArgmaxGpuKernel, half, int)
 }  // namespace kernel
 }  // namespace mindspore

+28 -38  mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h

@@ -23,11 +23,10 @@
#include "backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
#define ARGMAX_MAX_DIMENSION 2
template <typename T>
template <typename T, typename S>
class ArgmaxGpuKernel : public GpuKernel { class ArgmaxGpuKernel : public GpuKernel {
public: public:
ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {}
ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), bound_(0), outer_size_(0), inner_size_(0) {}
~ArgmaxGpuKernel() override = default; ~ArgmaxGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -37,47 +36,38 @@ class ArgmaxGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input = GetDeviceAddress<T>(inputs, 0);
-    int *output = GetDeviceAddress<int>(outputs, 0);
-    CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output,
-              reinterpret_cast<cudaStream_t>(stream_ptr));
+    S *output = GetDeviceAddress<S>(outputs, 0);
+    CalArgmax(input, bound_, outer_size_, inner_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input.";
-      return false;
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    int64_t dims = shape.size();
+    int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
+    if (axis < 0) {
+      axis += dims;
     }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output.";
-      return false;
+    input_size_ = sizeof(T);
+    for (auto x : shape) {
+      input_size_ *= x;
     }
-    auto output_type = GetValue<TypePtr>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type"));
-    if (output_type->type_id() != TypeId::kNumberTypeInt32) {
-      MS_LOG(EXCEPTION) << "Argmax only supports int32 output type.";
+    output_size_ = sizeof(S);
+    for (auto x : output_shape) {
+      output_size_ *= x;
     }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    if (input_shape.size() > ARGMAX_MAX_DIMENSION) {
-      MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but Argmax supports max " << ARGMAX_MAX_DIMENSION
-                        << "-D inputs.";
+    bound_ = static_cast<S>(shape[axis]);
+    if (shape[axis] != static_cast<size_t>(bound_)) {
+      MS_LOG(EXCEPTION) << "Bound's shape is larger than index type and overflows when casting.";
     }
-    axis_ = GetAttr<int64_t>(kernel_node, "axis");
-    if (axis_ < 0) {
-      axis_ += static_cast<int64_t>(input_shape.size());
+    outer_size_ = 1;
+    for (int64_t i = axis - 1; i >= 0; i--) {
+      outer_size_ *= shape[i];
     }
-    if (input_shape.size() == 1) {
-      batch_size_ = 0;
-      channel_size_ = input_shape[0];
-      input_size_ = sizeof(T) * channel_size_;
-      output_size_ = sizeof(int);
-    } else {
-      batch_size_ = input_shape[0];
-      channel_size_ = input_shape[1];
-      input_size_ = sizeof(T) * batch_size_ * channel_size_;
-      output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_;
+    inner_size_ = 1;
+    for (int64_t i = axis + 1; i < dims; i++) {
+      inner_size_ *= shape[i];
     }
     InitSizeLists();
     return true;
@@ -96,9 +86,9 @@ class ArgmaxGpuKernel : public GpuKernel {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-  size_t batch_size_;
-  size_t channel_size_;
-  int64_t axis_;
+  S bound_;
+  size_t outer_size_;
+  size_t inner_size_;
 };
 }  // namespace kernel
 }  // namespace mindspore
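
The round-trip cast guard that Init() applies to shape[axis] before storing it as bound_ can be exercised in isolation; a small sketch with hypothetical names, not MindSpore API:

#include <cstddef>
#include <cstdint>
#include <iostream>

// True when `dim` survives a round trip through index type S, i.e. the
// reduction-axis length is representable in S (same idea as the bound_ check).
template <typename S>
bool FitsIndexType(size_t dim) {
  S bound = static_cast<S>(dim);
  return dim == static_cast<size_t>(bound);
}

int main() {
  std::cout << FitsIndexType<int32_t>(1000) << '\n';           // 1: fits in int32
  std::cout << FitsIndexType<int32_t>(3000000000ULL) << '\n';  // 0: overflows int32
}

The comparison catches axis lengths that do not fit the index type S, which is exactly the failure the MS_LOG(EXCEPTION) above reports.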


+25 -61  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu

@@ -17,72 +17,36 @@
#include "argmax_impl.cuh" #include "argmax_impl.cuh"
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h" #include "include/cuda_fp16.h"
template <typename T>
__global__ void Argmax1D(const T *input, const int channel_size, int *output) {
int max_index = 0;
T max = input[0];
for (int pos = 1; pos < channel_size; pos++) {
if (max < input[pos]) {
max = input[pos];
max_index = pos;
template <typename T, typename S>
__global__ void Argmax(const T *input, const S bound, const size_t outer_size,
const size_t inner_size, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size * inner_size;
pos += gridDim.x * blockDim.x) {
size_t x = pos / inner_size % outer_size;
size_t y = pos % inner_size;
S idx = 0;
size_t input_offset = x * bound * inner_size + 0 * inner_size + y;
T max_data = input[input_offset];
for (S i = 1; i < bound; i++) {
input_offset = x * bound * inner_size + i * inner_size + y;
auto input_data = input[input_offset];
idx = input_data > max_data ? i : idx;
max_data = input_data > max_data ? input_data : max_data;
} }
output[pos] = idx;
} }
output[0] = max_index;
return; return;
} }
template <typename T>
__global__ void ArgmaxDefault2D(const T *input, const int batch_size, const int channel_size, int *output) {
int pos;
int max_index;
T max;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) {
max = input[i * channel_size];
max_index = 0;
for (int j = 1; j < channel_size; j++) {
pos = i * channel_size + j;
if (max < input[pos]) {
max = input[pos];
max_index = j;
}
}
output[i] = max_index;
}
return;
}
template <typename T>
__global__ void ArgmaxAxis2D(const T *input, const int batch_size, const int channel_size, int *output) {
int pos;
int max_index;
T max;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
max = input[i];
max_index = 0;
for (int j = 1; j < batch_size; j++) {
pos = j * channel_size + i;
if (max < input[pos]) {
max = input[pos];
max_index = j;
}
}
output[i] = max_index;
}
return;
}
template <typename T>
void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output,
cudaStream_t cuda_stream) {
if (batch_size == 0) {
Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output);
} else if (axis == 1) {
ArgmaxDefault2D<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
} else {
ArgmaxAxis2D<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
}
template <typename T, typename S>
void CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size,
S *output, cudaStream_t cuda_stream) {
Argmax<<<GET_BLOCKS(outer_size), GET_THREADS, 0, cuda_stream>>>(input, bound, outer_size, inner_size,
output);
return; return;
} }
template void CalArgmax<float>(const float *input, const int batch_size, const int channel_size, const int64_t axis,
int *output, cudaStream_t cuda_stream);
template void CalArgmax<half>(const half *input, const int batch_size, const int channel_size, const int64_t axis,
int *output, cudaStream_t cuda_stream);
template void CalArgmax<float, int>(const float *input, const int bound, const size_t outer_size,
const size_t inner_size, int *output, cudaStream_t cuda_stream);
template void CalArgmax<half, int>(const half *input, const int bound, const size_t outer_size,
const size_t inner_size, int *output, cudaStream_t cuda_stream);
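
To trace what the single generalized kernel computes, here is a serial C++ walk-through of its grid-stride loop (an illustration, not the CUDA source): each flattened position pos maps to an outer coordinate x = pos / inner_size and an inner coordinate y = pos % inner_size, and the reduction axis is scanned with stride inner_size.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Arbitrary demo shape: argmax over the middle axis of an
  // (outer_size, bound, inner_size) view of a row-major buffer.
  const size_t outer_size = 2, inner_size = 3, bound = 4;
  std::vector<float> input(outer_size * bound * inner_size);
  for (size_t i = 0; i < input.size(); i++) input[i] = static_cast<float>(i % 7);
  std::vector<int> output(outer_size * inner_size);

  const size_t stride = 4;  // stands in for gridDim.x * blockDim.x
  for (size_t start = 0; start < stride; start++) {  // one iteration per "thread"
    for (size_t pos = start; pos < outer_size * inner_size; pos += stride) {
      size_t x = pos / inner_size % outer_size;  // outer coordinate
      size_t y = pos % inner_size;               // inner coordinate
      int idx = 0;
      float max_data = input[x * bound * inner_size + y];
      for (size_t i = 1; i < bound; i++) {       // scan the reduction axis
        float v = input[x * bound * inner_size + i * inner_size + y];
        if (v > max_data) {
          max_data = v;
          idx = static_cast<int>(i);
        }
      }
      output[pos] = idx;
    }
  }
  for (int v : output) std::cout << v << ' ';  // prints "2 1 1 0 0 2"
  std::cout << '\n';
  return 0;
}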

+2 -2  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh

@@ -16,8 +16,8 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_
-template <typename T>
-void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output,
+template <typename T, typename S>
+void CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output,
                cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_

+44 -16  tests/st/ops/cpu/test_argmax_op.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
+import random
+from functools import reduce
 import numpy as np
 import pytest
@@ -20,33 +22,59 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
-from mindspore.ops import operations as P
+import mindspore.ops as ops
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
 
 
 class NetArgmax(nn.Cell):
-    def __init__(self):
+    def __init__(self, axis=0):
         super(NetArgmax, self).__init__()
-        self.argmax = P.Argmax(output_type=mstype.int32)
-        x = Tensor(np.array([[1., 20., 5.],
-                             [67., 8., 9.],
-                             [130., 24., 15.]]).astype(np.float32))
-        self.x = Parameter(initializer(x, x.shape), name='x')
+        self.argmax = ops.Argmax(axis=axis, output_type=mstype.int32)
 
-    def construct(self):
-        return self.argmax(self.x)
+    def construct(self, x):
+        return self.argmax(x)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_argmax():
-    Argmax = NetArgmax()
-    output = Argmax()
-    print("================================")
+def test_argmax_1d():
+    x = Tensor(np.array([1., 20., 5.]).astype(np.float32))
+    Argmax = NetArgmax(axis=0)
+    output = Argmax(x)
+    expect = np.array([1]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_argmax_2d():
+    x = Tensor(np.array([[1., 20., 5.],
+                         [67., 8., 9.],
+                         [130., 24., 15.]]).astype(np.float32))
+    Argmax_axis_0 = NetArgmax(axis=0)
+    output = Argmax_axis_0(x)
+    expect = np.array([2, 2, 2]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+    Argmax_axis_1 = NetArgmax(axis=1)
+    output = Argmax_axis_1(x)
     expect = np.array([1, 0, 0]).astype(np.float32)
-    print(output)
     assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_argmax_high_dims():
+    for dim in range(3, 10):
+        shape = np.random.randint(1, 10, size=dim)
+        x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32)
+        x = x.reshape(shape)
+        rnd_axis = random.randint(-dim + 1, dim - 1)
+        Argmax = NetArgmax(axis=rnd_axis)
+        ms_output = Argmax(Tensor(x))
+        np_output = np.argmax(x, axis=rnd_axis)
+        assert (ms_output.asnumpy() == np_output).all()

+51 -25  tests/st/ops/gpu/test_argmax_op.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
+import random
+from functools import reduce
 import numpy as np
 import pytest
@@ -20,43 +22,67 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
+import mindspore.ops as ops
 
 
 class NetArgmax(nn.Cell):
-    def __init__(self):
+    def __init__(self, axis=0):
         super(NetArgmax, self).__init__()
-        axis1 = 0
-        axis2 = -1
-        self.argmax1 = P.Argmax(axis1, output_type=mstype.int32)
-        self.argmax2 = P.Argmax(axis2, output_type=mstype.int32)
-        self.argmax3 = P.Argmax(output_type=mstype.int32)
+        self.argmax = ops.Argmax(axis, output_type=mstype.int32)
 
     def construct(self, x):
-        return (self.argmax1(x), self.argmax2(x), self.argmax3(x))
+        return self.argmax(x)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_argmax():
+def test_argmax_1d():
+    for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
+        context.set_context(mode=mode, device_target="GPU")
+        x = Tensor(np.array([1., 20., 5.]).astype(np.float32))
+        Argmax = NetArgmax(axis=0)
+        output = Argmax(x)
+        expect = np.array([1]).astype(np.float32)
+        assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_argmax_2d():
+    for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
+        context.set_context(mode=mode, device_target="GPU")
-    x = Tensor(np.array([[1., 20., 5.],
-                         [67., 8., 9.],
-                         [130., 24., 15.],
-                         [0.3, -0.4, -15.]]).astype(np.float32))
-    expect1 = np.array([2, 2, 2]).astype(np.int32)
-    expect2 = np.array([1, 0, 0, 0]).astype(np.int32)
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    argmax = NetArgmax()
-    output = argmax(x)
-    assert (output[0].asnumpy() == expect1).all()
-    assert (output[1].asnumpy() == expect2).all()
-    assert (output[2].asnumpy() == expect2).all()
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    argmax1 = NetArgmax()
-    output1 = argmax1(x)
-    assert (output1[0].asnumpy() == expect1).all()
-    assert (output1[1].asnumpy() == expect2).all()
-    assert (output1[2].asnumpy() == expect2).all()
+        x = Tensor(np.array([[1., 20., 5.],
+                             [67., 8., 9.],
+                             [130., 24., 15.],
+                             [0.3, -0.4, -15.]]).astype(np.float32))
+        Argmax_axis_0 = NetArgmax(axis=0)
+        output = Argmax_axis_0(x)
+        expect = np.array([2, 2, 2]).astype(np.int32)
+        assert (output.asnumpy() == expect).all()
+        Argmax_axis_1 = NetArgmax(axis=1)
+        output = Argmax_axis_1(x)
+        expect = np.array([1, 0, 0, 0]).astype(np.int32)
+        assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_argmax_high_dims():
+    for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
+        context.set_context(mode=mode, device_target="GPU")
+        for dim in range(3, 10):
+            shape = np.random.randint(1, 10, size=dim)
+            x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32)
+            x = x.reshape(shape)
+            rnd_axis = random.randint(-dim + 1, dim - 1)
+            Argmax = NetArgmax(axis=rnd_axis)
+            ms_output = Argmax(Tensor(x))
+            np_output = np.argmax(x, axis=rnd_axis)
+            assert (ms_output.asnumpy() == np_output).all()
