Browse Source

Update OneHot operator

feature/build-system-rewrite
Zwink 4 years ago
parent
commit
cceb9d0b81
8 changed files with 552 additions and 23 deletions
  1. +2
    -1
      mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.cc
  2. +56
    -13
      mindspore/ccsrc/plugin/device/cpu/kernel/one_hot_cpu_kernel.cc
  3. +364
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/one_hot_cpu_kernel.h
  4. +8
    -4
      mindspore/core/ops/one_hot.cc
  5. +1
    -1
      mindspore/core/ops/one_hot.h
  6. +1
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
  7. +116
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/one_hot.py
  8. +4
    -3
      mindspore/python/mindspore/ops/operations/nn_ops.py

+ 2
- 1
mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.cc View File

@@ -28,7 +28,8 @@ namespace kernel {
* ...
* }
*/
std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap = {};
std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap = {
{prim::kPrimOneHot->name(), {{"depth", 1}}}};
bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *info) {
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);


+ 56
- 13
mindspore/ccsrc/plugin/device/cpu/kernel/one_hot_cpu_kernel.cc View File

@@ -15,6 +15,8 @@
*/

#include "plugin/device/cpu/kernel/one_hot_cpu_kernel.h"
#include <string>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
@@ -22,24 +24,55 @@ namespace kernel {
namespace {
constexpr size_t kOneHotInputsNum = 3;
constexpr size_t kOneHotOutputsNum = 1;
// Dispatch helper for OneHotCpuKernelMod::Launch: for one indices dtype
// (DTYPE/TYPE), switch on the output dtype (ODTYPE) and call the matching
// LaunchKernel<indices_type, output_type> instantiation.
#define INPUT_COMPUTE_CASE(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS)     \
  case (DTYPE): {                                                    \
    switch (ODTYPE) {                                                \
      INPUT_COMPUTE_CASE_INT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS)   \
      INPUT_COMPUTE_CASE_FLOAT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \
      default:                                                       \
        MS_LOG(EXCEPTION) << "For OneHot, the dtype of output is not supported."; \
    }                                                                \
    break;                                                           \
  }

// Integer output dtypes (signed and unsigned, 8..64 bit).
#define INPUT_COMPUTE_CASE_INT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt8, int8_t, INPUTS, OUTPUTS)     \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt16, int16_t, INPUTS, OUTPUTS)   \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt32, int32_t, INPUTS, OUTPUTS)   \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt64, int64_t, INPUTS, OUTPUTS)   \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt8, uint8_t, INPUTS, OUTPUTS)   \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt16, uint16_t, INPUTS, OUTPUTS) \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt32, uint32_t, INPUTS, OUTPUTS) \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt64, uint64_t, INPUTS, OUTPUTS)

// Floating-point, complex, bool and string output dtypes.
// NOTE: kNumberTypeFloat32 must map to `float`, not `float_t` — `float_t`
// (<cmath>) is FLT_EVAL_METHOD-dependent and may alias `double`, which would
// mis-size every element access in LaunchKernel.
#define INPUT_COMPUTE_CASE_FLOAT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS)                \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeComplex64, std::complex<float>, INPUTS, OUTPUTS)   \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeComplex128, std::complex<double>, INPUTS, OUTPUTS) \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeFloat64, double, INPUTS, OUTPUTS)              \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeFloat32, float, INPUTS, OUTPUTS)               \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeFloat16, float16, INPUTS, OUTPUTS)             \
  OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeBool, bool, INPUTS, OUTPUTS)                   \
  OUTPUT_COMPUTE_CASE(TYPE, kObjectTypeString, std::string, INPUTS, OUTPUTS)

// One `case` of the output-dtype switch: run the typed kernel and break.
#define OUTPUT_COMPUTE_CASE(TYPE, ODTYPE, OTYPE, INPUTS, OUTPUTS) \
  case (ODTYPE): {                                                \
    LaunchKernel<TYPE, OTYPE>(INPUTS, OUTPUTS);                   \
    break;                                                        \
  }
} // namespace

void OneHotCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape.size() < 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of output should be greater than or equal to 2, but got "
<< output_shape.size() << ".";
}
int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
if (axis != -1 && LongToSize(axis) >= output_shape.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the 'axis' should be -1, or an int which is less than the dimension of output, but got "
<< axis << ", got the dimension of output " << output_shape.size();
}

if (axis == -1) {
axis_ = output_shape.size() - 1;
} else {
@@ -56,12 +89,24 @@ bool OneHotCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, c
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kOneHotInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOneHotOutputsNum, kernel_name_);
const auto *indices = reinterpret_cast<int *>(inputs[0]->addr);
auto on_value = reinterpret_cast<float *>(inputs[1]->addr)[0];
auto off_value = reinterpret_cast<float *>(inputs[2]->addr)[0];
auto *output = reinterpret_cast<float *>(outputs[0]->addr);
size_t elem_num = inputs[0]->size / sizeof(int);
switch (input_dtype_) {
INPUT_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, output_dtype_, inputs, outputs);
INPUT_COMPUTE_CASE(kNumberTypeInt32, int32_t, output_dtype_, inputs, outputs);
INPUT_COMPUTE_CASE(kNumberTypeInt64, int64_t, output_dtype_, inputs, outputs);
default:
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input 'x' "
<< TypeIdToType(input_dtype_)->ToString() << " not support.";
}
return true;
}

template <typename ID, typename OD>
void OneHotCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
const auto *indices = reinterpret_cast<ID *>(inputs[0]->addr);
auto on_value = reinterpret_cast<OD *>(inputs[1]->addr)[0];
auto off_value = reinterpret_cast<OD *>(inputs[2]->addr)[0];
auto *output = reinterpret_cast<OD *>(outputs[0]->addr);
size_t elem_num = inputs[0]->size / sizeof(ID);
auto task = [this, &indices, &on_value, &off_value, &output](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
size_t stride_num = i / stride_;
@@ -78,8 +123,6 @@ bool OneHotCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, c
}
};
ParallelLaunchAutoSearch(task, elem_num, this, &parallel_search_info_);

return true;
}
} // namespace kernel
} // namespace mindspore

+ 364
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/one_hot_cpu_kernel.h View File

@@ -35,12 +35,375 @@ class OneHotCpuKernelMod : public NativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) override;

private:
  // Type-dispatched implementation: ID is the indices dtype, OD is the shared
  // dtype of on_value, off_value and the output.
  template <typename ID, typename OD>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  TypeId input_dtype_{kTypeUnknown};   // device dtype of input 0 (indices)
  TypeId output_dtype_{kTypeUnknown};  // device dtype of output 0
  size_t depth_{0};                    // size of the one-hot dimension
  size_t stride_{0};                   // elements per step along the one-hot axis — presumably product of trailing dims; confirm in InitKernel
  size_t axis_{0};                     // resolved non-negative one-hot axis (-1 attr maps to last dim)
};

// Registers one OneHot CPU kernel variant: indices have dtype IDX_T; on_value,
// off_value and the output all share the value dtype VAL_T.
// The old catch-all registration `MS_REG_CPU_KERNEL(OneHot, KernelAttr(), ...)`
// is removed — keeping it alongside the typed variants would register the op
// twice, once with an empty (match-anything) attr.
#define ONE_HOT_CPU_REG(IDX_T, VAL_T)        \
  MS_REG_CPU_KERNEL(OneHot,                  \
                    KernelAttr()             \
                      .AddInputAttr(IDX_T)   \
                      .AddInputAttr(VAL_T)   \
                      .AddInputAttr(VAL_T)   \
                      .AddOutputAttr(VAL_T), \
                    OneHotCpuKernelMod)

// indices: uint8
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeUInt8);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeUInt16);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeUInt32);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeUInt64);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeInt8);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeInt16);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeInt32);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeInt64);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeFloat16);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeFloat32);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeFloat64);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeBool);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeComplex64);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kNumberTypeComplex128);
ONE_HOT_CPU_REG(kNumberTypeUInt8, kObjectTypeString);

// indices: int32
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeUInt8);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeUInt16);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeUInt32);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeUInt64);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeInt8);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeInt16);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeInt32);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeInt64);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeFloat16);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeFloat32);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeFloat64);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeBool);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeComplex64);
ONE_HOT_CPU_REG(kNumberTypeInt32, kNumberTypeComplex128);
ONE_HOT_CPU_REG(kNumberTypeInt32, kObjectTypeString);

// indices: int64
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeUInt8);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeUInt16);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeUInt32);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeUInt64);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeInt8);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeInt16);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeInt32);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeInt64);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeFloat16);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeFloat32);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeFloat64);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeBool);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeComplex64);
ONE_HOT_CPU_REG(kNumberTypeInt64, kNumberTypeComplex128);
ONE_HOT_CPU_REG(kNumberTypeInt64, kObjectTypeString);

#undef ONE_HOT_CPU_REG
} // namespace kernel
} // namespace mindspore
} // namespace kernel
} // namespace mindspore



+ 8
- 4
mindspore/core/ops/one_hot.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -74,13 +74,17 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve

// Infers the output dtype of OneHot and validates the input dtypes.
// indices: uint8/int32/int64; depth: any integer scalar type;
// on_value/off_value: must share one supported value dtype, which becomes the
// output dtype. (The scraped diff had left the old narrow {int32,int64} indices
// check and the old {float16,float32} return alongside the new ones; only the
// updated statements are kept.)
TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = prim->name();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex0]->BuildType(),
                                                   {kUInt8, kInt32, kInt64}, op_name);
  (void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[kInputIndex1]->BuildType(),
                                             {kInt8, kInt16, kInt32, kInt64}, op_name);
  // NOTE(review): the key "off_dtype" looks like it should be "off_value" to
  // match the input name in error messages — kept as-is; confirm before renaming.
  std::map<std::string, TypePtr> args = {{"on_value", input_args[kInputIndex2]->BuildType()},
                                         {"off_dtype", input_args[kInputIndex3]->BuildType()}};
  return CheckAndConvertUtils::CheckTensorTypeSame(
    args,
    {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8, kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32,
     kFloat64, kComplex64, kComplex128},
    op_name);
}
} // namespace
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 1
mindspore/core/ops/one_hot.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+ 1
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py View File

@@ -115,3 +115,4 @@ from .environ_destroy_all import _environ_destroy_all_aicpu
from .cross import _cross_aicpu
from .cummax import _cummax_aicpu
from .floor_div import _floor_div_aicpu
from .one_hot import _one_hot_aicpu

+ 116
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/one_hot.py View File

@@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""OneHot op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# OneHot AiCPU op info: inputs are (indices, depth, on_value, off_value);
# depth is always int64; on_value, off_value and the output must share one dtype.
# BUG FIX: the int64-indices uint32 and uint64 rows previously declared a
# U16_Default output (copy-paste error) — every other row, and the matching CPU
# kernel registrations, make the output dtype equal to the on/off value dtype.
one_hot_op_info = AiCPURegOp("OneHot") \
    .fusion_type("OPAQUE") \
    .input(0, "indices", "required") \
    .input(1, "depth", "required") \
    .input(2, "on_value", "required") \
    .input(3, "off_value", "required") \
    .output(0, "output", "required") \
    .attr("axis", "int") \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U16_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U64_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I8_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I16_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.F64_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.C64_Default,
                  DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.C128_Default,
                  DataType.C128_Default, DataType.C128_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.BOOL_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U8_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U16_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I8_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I16_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F64_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.C64_Default,
                  DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.C128_Default,
                  DataType.C128_Default, DataType.C128_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.BOOL_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U8_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U16_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I8_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I16_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F64_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.C64_Default,
                  DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.C128_Default,
                  DataType.C128_Default, DataType.C128_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default,
                  DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()

@op_info_register(one_hot_op_info)
def _one_hot_aicpu():
    """OneHot AiCPU register.

    The decorator performs the registration as an import-time side effect;
    the function body is intentionally a no-op.
    """
    return

+ 4
- 3
mindspore/python/mindspore/ops/operations/nn_ops.py View File

@@ -3283,10 +3283,11 @@ class OneHot(Primitive):

Inputs:
- **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
Data type must be uint8, int32 or int64.
- **depth** (int) - A scalar defining the depth of the one-hot dimension.
- **on_value** (Tensor) - A value to fill in output when `indices[j] = i`.
Supported data types: uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64,
bool, complex64, complex128.
- **off_value** (Tensor) - A value to fill in output when `indices[j] != i`.
Has the same data type as `on_value`.

@@ -3295,7 +3296,7 @@ class OneHot(Primitive):

Raises:
TypeError: If `axis` or `depth` is not an int.
TypeError: If dtype of `indices` is not uint8, int32 or int64.
TypeError: If `indices`, `on_value` or `off_value` is not a Tensor.
ValueError: If `axis` is not in range [-1, len(indices_shape)].
ValueError: If `depth` is less than 0.


Loading…
Cancel
Save