Browse Source

Add Select ops for cpu

Add select ops for cpu

Add select ops for cpu

Add select ops for cpu -- remove useless methods

Add select ops for cpu -- remove useless methods

Add select ops for cpu -- remove useless methods
tags/v1.1.0
hebotao 5 years ago
parent
commit
ae68883945
4 changed files with 205 additions and 1 deletions
  1. +53
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.cc
  2. +66
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.h
  3. +1
    -1
      mindspore/ops/operations/array_ops.py
  4. +85
    -0
      tests/st/ops/cpu/test_select_op.py

+ 53
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.cc View File

@@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/cpu/select_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
template <typename T>
void SelectCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SelectCpuKernel needs 3 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SelectCpuKernel needs 1 output.";
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t x : shape) {
element_num_ *= x;
}
return;
}

template <typename T>
bool SelectCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto input_cond = reinterpret_cast<bool *>(inputs[0]->addr);
auto input_x = reinterpret_cast<T *>(inputs[1]->addr);
auto input_y = reinterpret_cast<T *>(inputs[2]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr);
for (size_t pos = 0; pos < element_num_; pos++) {
output[pos] = input_cond[pos] ? input_x[pos] : input_y[pos];
}
return true;
}

} // namespace kernel
} // namespace mindspore

+ 66
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.h View File

@@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELECT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELECT_CPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class SelectCPUKernel : public CPUKernel {
public:
SelectCPUKernel() = default;
~SelectCPUKernel() override = default;

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

void InitKernel(const CNodePtr &kernel_node) override;

private:
size_t element_num_{1};
};

MS_REG_CPU_KERNEL_T(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
SelectCPUKernel, float);

MS_REG_CPU_KERNEL_T(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
SelectCPUKernel, float16);

MS_REG_CPU_KERNEL_T(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
SelectCPUKernel, int);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELECT_CPU_KERNEL_H_

+ 1
- 1
mindspore/ops/operations/array_ops.py View File

@@ -2533,7 +2533,7 @@ class Select(PrimitiveWithInfer):
Tensor, has the same shape as `input_x`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.

Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> select = ops.Select()


+ 85
- 0
tests/st/ops/cpu/test_select_op.py View File

@@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.select = P.Select()

def construct(self, cond_op, input_x, input_y):
return self.select(cond_op, input_x, input_y)


context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_select_float32():
cond = np.array([[True, False], [True, False]]).astype(np.bool)
x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)
y = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
select = Net()
output = select(Tensor(cond), Tensor(x), Tensor(y))
print(output.asnumpy())
expect = [[1.2, 2], [1, 4.0]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_select_float16():
cond = np.array([[True, False], [True, False]]).astype(np.bool)
x = np.array([[1.2, 1], [1, 0]]).astype(np.float16)
y = np.array([[1, 2], [3, 4.0]]).astype(np.float16)
select = Net()
output = select(Tensor(cond), Tensor(x), Tensor(y))
print(output.asnumpy())
expect = [[1.2, 2], [1, 4.0]]
error = np.ones(shape=[2, 2]) * 1.0e-3
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_select_int32():
cond = np.array([[True, False], [True, False]]).astype(np.bool)
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
y = np.array([[1, 2], [3, 4]]).astype(np.int32)
select = Net()
output = select(Tensor(cond), Tensor(x), Tensor(y))
print(output.asnumpy())
expect = [[12, 2], [1, 4]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)

Loading…
Cancel
Save