Browse Source

add CPU ops: Greater/GreaterEqual/Range/GatherNd for center net

tags/v1.2.0-rc1
caojian05 CaoJian 5 years ago
parent
commit
06fb28c703
14 changed files with 734 additions and 5 deletions
  1. +27
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
  2. +26
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
  3. +6
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
  4. +104
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.cc
  5. +67
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.h
  6. +56
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.cc
  7. +54
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.h
  8. +1
    -1
      mindspore/nn/layer/math.py
  9. +1
    -1
      mindspore/ops/operations/array_ops.py
  10. +2
    -2
      mindspore/ops/operations/math_ops.py
  11. +188
    -0
      tests/st/ops/cpu/test_gathernd_op.py
  12. +70
    -0
      tests/st/ops/cpu/test_greater_equal_op.py
  13. +70
    -0
      tests/st/ops/cpu/test_greater_op.py
  14. +62
    -0
      tests/st/ops/cpu/test_range_op.py

+ 27
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc View File

@@ -167,6 +167,24 @@ void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T
} }
} }


template <typename T>
void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] > input2[idx[1]];
}
}

template <typename T>
void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] >= input2[idx[1]];
}
}

void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -190,6 +208,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
operate_type_ = EQUAL; operate_type_ = EQUAL;
} else if (kernel_name == prim::kPrimNotEqual->name()) { } else if (kernel_name == prim::kPrimNotEqual->name()) {
operate_type_ = NOTEQUAL; operate_type_ = NOTEQUAL;
} else if (kernel_name == prim::kPrimGreater->name()) {
operate_type_ = GREATER;
} else if (kernel_name == prim::kPrimGreaterEqual->name()) {
operate_type_ = GREATEREQUAL;
} else if (kernel_name == prim::kPrimAssignAdd->name()) { } else if (kernel_name == prim::kPrimAssignAdd->name()) {
operate_type_ = ASSIGNADD; operate_type_ = ASSIGNADD;
} else if (kernel_name == prim::kPrimSquaredDifference->name()) { } else if (kernel_name == prim::kPrimSquaredDifference->name()) {
@@ -301,6 +323,11 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end)); threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == NOTEQUAL) { } else if (operate_type_ == NOTEQUAL) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end)); threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == GREATER) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == GREATEREQUAL) {
threads.emplace_back(
std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end));
} else { } else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_; MS_LOG(EXCEPTION) << "Not support " << operate_type_;
} }


+ 26
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h View File

@@ -63,6 +63,10 @@ class ArithmeticCPUKernel : public CPUKernel {
void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
template <typename T> template <typename T>
void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end); void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end);
template <typename T>
void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
std::vector<size_t> input_shape0_; std::vector<size_t> input_shape0_;
std::vector<size_t> input_shape1_; std::vector<size_t> input_shape1_;
std::vector<size_t> input_element_num0_; std::vector<size_t> input_element_num0_;
@@ -213,6 +217,28 @@ MS_REG_CPU_KERNEL(
SquaredDifference, SquaredDifference,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel); ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore




+ 6
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h View File

@@ -53,6 +53,9 @@ const char END[] = "end";
const char SIZE[] = "size"; const char SIZE[] = "size";
const char USE_NESTEROV[] = "use_nesterov"; const char USE_NESTEROV[] = "use_nesterov";
const char GROUP[] = "group"; const char GROUP[] = "group";
const char START[] = "start";
const char LIMIT[] = "limit";
const char DELTA[] = "delta";
enum OperateType { enum OperateType {
ADD = 0, ADD = 0,
@@ -79,7 +82,9 @@ enum OperateType {
EQUAL, EQUAL,
NOTEQUAL, NOTEQUAL,
FLOOR, FLOOR,
SQUAREDDIFFERENCE
SQUAREDDIFFERENCE,
GREATER,
GREATEREQUAL,
}; };
class CPUKernel : public kernel::KernelMod { class CPUKernel : public kernel::KernelMod {


+ 104
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.cc View File

@@ -0,0 +1,104 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/gathernd_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {

void GatherNdCPUKernel::InitKernel(const CNodePtr &kernel_node) {
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);

// ReShape()
size_t dim_of_indices = 1;
for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); ++i) {
dim_of_indices *= indices_shapes_[i];
}

size_t dim_after_indices = 1;
size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)];
for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) {
dim_after_indices *= input_shapes_[i];
}

dims_.emplace_back(dim_of_indices);
dims_.emplace_back(dim_after_indices);
dims_.emplace_back(dim_indices_last);

batch_strides_.resize(dim_indices_last, 0);
batch_indices_.resize(dim_indices_last, 0);

if (dim_indices_last > 0) {
batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1];
batch_indices_[dim_indices_last - 1] = dims_[1];
}

for (size_t i = dim_indices_last - 1; i > 0; --i) {
batch_strides_[i - 1] = input_shapes_[i - 1];
batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
}
}

bool GatherNdCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32) {
return LaunchKernel<int32_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) {
return LaunchKernel<int64_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
return LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
return LaunchKernel<double>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_);
}
}

template <typename T>
bool GatherNdCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);

//
size_t output_dim0 = dims_[0];
size_t output_dim1 = dims_[1];
size_t indices_dim1 = dims_[2];

int num = output_dim0 * output_dim1;

for (int write_index = 0; write_index < num; write_index++) {
int i = write_index / output_dim1 % output_dim0;
int j = write_index % output_dim1;

int read_index = 0;
for (size_t k = 0; k < indices_dim1; k++) {
size_t ind = indices_dim1 * i + k;
int indices_i = indices_addr[ind];
read_index += indices_i * batch_indices_[k];
}
read_index += j;
output_addr[write_index] = input_addr[read_index];
}
return true;
}
} // namespace kernel
} // namespace mindspore

+ 67
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/gathernd_cpu_kernel.h View File

@@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class GatherNdCPUKernel : public CPUKernel {
public:
GatherNdCPUKernel() = default;
~GatherNdCPUKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override;

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

private:
std::vector<size_t> input_shapes_;
std::vector<size_t> indices_shapes_;
std::vector<size_t> output_shapes_;

std::vector<size_t> dims_;
std::vector<int> batch_indices_;
std::vector<int> batch_strides_;

TypeId dtype_{kTypeUnknown};
};

MS_REG_CPU_KERNEL(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherNdCPUKernel);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_

+ 56
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/range_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
void RangeCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);

start_ = AnfAlgo::GetNodeAttr<float>(kernel_node, START);
limit_ = AnfAlgo::GetNodeAttr<float>(kernel_node, LIMIT);
delta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, DELTA);
}

bool RangeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32) {
return LaunchKernel<int32_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) {
return LaunchKernel<int64_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
return LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
return LaunchKernel<double>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_);
}
}

template <typename T>
bool RangeCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t elem_num = outputs[0]->size / sizeof(T);
for (size_t i = 0; i < elem_num; i++) {
output_addr[i] = start_ + i * delta_;
}
return true;
}
} // namespace kernel
} // namespace mindspore

+ 54
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/range_cpu_kernel.h View File

@@ -0,0 +1,54 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class RangeCPUKernel : public CPUKernel {
public:
RangeCPUKernel() = default;
~RangeCPUKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override;

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

private:
TypeId dtype_{kTypeUnknown};
int64_t start_;
int64_t limit_;
int64_t delta_;
};

MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
RangeCPUKernel);

} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_

+ 1
- 1
mindspore/nn/layer/math.py View File

@@ -116,7 +116,7 @@ class Range(Cell):
Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float. Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.


Supported Platforms: Supported Platforms:
``Ascend``
``Ascend`` ``CPU``


Examples: Examples:
>>> net = nn.Range(1, 8, 2) >>> net = nn.Range(1, 8, 2)


+ 1
- 1
mindspore/ops/operations/array_ops.py View File

@@ -3078,7 +3078,7 @@ class GatherNd(PrimitiveWithInfer):
Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:]. Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].


Supported Platforms: Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``


Examples: Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)


+ 2
- 2
mindspore/ops/operations/math_ops.py View File

@@ -2698,7 +2698,7 @@ class Greater(_LogicBinaryOp):
Tensor, the shape is the same as the one after broadcasting,and the data type is bool. Tensor, the shape is the same as the one after broadcasting,and the data type is bool.


Supported Platforms: Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``


Examples: Examples:
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
@@ -2739,7 +2739,7 @@ class GreaterEqual(_LogicBinaryOp):
Tensor, the shape is the same as the one after broadcasting,and the data type is bool. Tensor, the shape is the same as the one after broadcasting,and the data type is bool.


Supported Platforms: Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``


Examples: Examples:
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)


+ 188
- 0
tests/st/ops/cpu/test_gathernd_op.py View File

@@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class OpNetWrapper(nn.Cell):
def __init__(self, op):
super(OpNetWrapper, self).__init__()
self.op = op

def construct(self, *inputs):
return self.op(*inputs)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case1_basic_func():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [0, 3]
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case2_indices_to_matrix():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[1], [0]]), mindspore.int32)
params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[2, 3], [0, 1]]
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case3_indices_to_3d_tensor():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[1]]), mindspore.int32) # (1, 1)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[[4, 5], [6, 7]]] # (1, 2, 2)
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case4():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[0, 1], [1, 0]]), mindspore.int32) # (2, 2)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[2, 3], [4, 5]] # (2, 2)
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case5():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[0, 0, 1], [1, 0, 1]]), mindspore.int32) # (2, 3)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [1, 5] # (2,)
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case6():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[[0, 0]], [[0, 1]]]), mindspore.int32) # (2, 1, 2)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[[0, 1]], [[2, 3]]] # (2, 1, 2)
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case7():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[[1]], [[0]]]), mindspore.int32) # (2, 1, 1)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[[[4, 5], [6, 7]]], [[[0, 1], [2, 3]]]] # (2, 1, 2, 2)
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case8():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[[0, 1], [1, 0]], [[0, 0], [1, 1]]]), mindspore.int32) # (2, 2, 2)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[[2, 3], [4, 5]], [[0, 1], [6, 7]]] # (2, 2, 2)
assert np.allclose(outputs.asnumpy(), np.array(expected))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case9():
op = P.GatherNd()
op_wrapper = OpNetWrapper(op)

indices = Tensor(np.array([[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]), mindspore.int32) # (2, 2, 3)
params = Tensor(np.array([[[0, 1], [2, 3]],
[[4, 5], [6, 7]]]), mindspore.int64) # (2, 2, 2)
outputs = op_wrapper(params, indices)
print(outputs)
expected = [[1, 5], [3, 6]] # (2, 2, 2)
assert np.allclose(outputs.asnumpy(), np.array(expected))


if __name__ == '__main__':
test_case1_basic_func()
test_case2_indices_to_matrix()
test_case3_indices_to_3d_tensor()
test_case4()
test_case5()
test_case6()
test_case7()
test_case8()
test_case9()

+ 70
- 0
tests/st/ops/cpu/test_greater_equal_op.py View File

@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class OpNetWrapper(nn.Cell):
def __init__(self, op):
super(OpNetWrapper, self).__init__()
self.op = op

def construct(self, *inputs):
return self.op(*inputs)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_int32():
op = P.GreaterEqual()
op_wrapper = OpNetWrapper(op)

input_x = Tensor(np.array([1, 2, 3]).astype(np.int32))
input_y = Tensor(np.array([3, 2, 1]).astype(np.int32))
outputs = op_wrapper(input_x, input_y)

print(outputs)
assert outputs.shape == (3,)
assert np.allclose(outputs.asnumpy(), [False, True, True])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_float32():
op = P.GreaterEqual()
op_wrapper = OpNetWrapper(op)

input_x = Tensor(np.array([1, 2, -1]).astype(np.float32))
input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32))
outputs = op_wrapper(input_x, input_y)

print(outputs)
assert outputs.shape == (3,)
assert np.allclose(outputs.asnumpy(), [True, True, True])


if __name__ == '__main__':
test_int32()
test_float32()

+ 70
- 0
tests/st/ops/cpu/test_greater_op.py View File

@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class OpNetWrapper(nn.Cell):
def __init__(self, op):
super(OpNetWrapper, self).__init__()
self.op = op

def construct(self, *inputs):
return self.op(*inputs)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_int32():
op = P.Greater()
op_wrapper = OpNetWrapper(op)

input_x = Tensor(np.array([1, 2, 3]).astype(np.int32))
input_y = Tensor(np.array([3, 2, 1]).astype(np.int32))
outputs = op_wrapper(input_x, input_y)

print(outputs)
assert outputs.shape == (3,)
assert np.allclose(outputs.asnumpy(), [False, False, True])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_float32():
op = P.Greater()
op_wrapper = OpNetWrapper(op)

input_x = Tensor(np.array([1, 2, -1]).astype(np.float32))
input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32))
outputs = op_wrapper(input_x, input_y)

print(outputs)
assert outputs.shape == (3,)
assert np.allclose(outputs.asnumpy(), [True, False, False])


if __name__ == '__main__':
test_int32()
test_float32()

+ 62
- 0
tests/st/ops/cpu/test_range_op.py View File

@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class OpNetWrapper(nn.Cell):
def __init__(self, op):
super(OpNetWrapper, self).__init__()
self.op = op

def construct(self, *inputs):
return self.op(*inputs)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_int():
op = nn.Range(0, 100, 10)
op_wrapper = OpNetWrapper(op)

outputs = op_wrapper()
print(outputs)
assert outputs.shape == (10,)
assert np.allclose(outputs.asnumpy(), range(0, 100, 10))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_float():
op = nn.Range(10., 100., 20.)
op_wrapper = OpNetWrapper(op)

outputs = op_wrapper()
print(outputs)
assert outputs.shape == (5,)
assert np.allclose(outputs.asnumpy(), [10., 30., 50., 70., 90.])


if __name__ == '__main__':
test_int()
test_float()

Loading…
Cancel
Save