| @@ -28,7 +28,8 @@ namespace kernel { | |||
| * ... | |||
| * } | |||
| */ | |||
| std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap = {}; | |||
| std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap = { | |||
| {prim::kPrimOneHot->name(), {{"depth", 1}}}}; | |||
| bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *info) { | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -15,6 +15,8 @@ | |||
| */ | |||
| #include "plugin/device/cpu/kernel/one_hot_cpu_kernel.h" | |||
| #include <string> | |||
| #include <complex> | |||
| #include "plugin/device/cpu/hal/device/cpu_device_address.h" | |||
| namespace mindspore { | |||
| @@ -22,24 +24,55 @@ namespace kernel { | |||
| namespace { | |||
| constexpr size_t kOneHotInputsNum = 3; | |||
| constexpr size_t kOneHotOutputsNum = 1; | |||
| #define INPUT_COMPUTE_CASE(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \ | |||
| case (DTYPE): { \ | |||
| switch (ODTYPE) { \ | |||
| INPUT_COMPUTE_CASE_INT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \ | |||
| INPUT_COMPUTE_CASE_FLOAT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \ | |||
| default: \ | |||
| MS_LOG(EXCEPTION) << " For OneHot the dtype of output not support."; \ | |||
| } \ | |||
| break; \ | |||
| } | |||
| #define INPUT_COMPUTE_CASE_INT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt8, int8_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt16, int16_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt32, int32_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeInt64, int64_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt8, uint8_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt16, uint16_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt32, uint32_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeUInt64, uint64_t, INPUTS, OUTPUTS) | |||
| #define INPUT_COMPUTE_CASE_FLOAT(DTYPE, TYPE, ODTYPE, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeComplex64, std::complex<float>, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeComplex128, std::complex<double>, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeFloat64, double, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeFloat32, float_t, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeFloat16, float16, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kNumberTypeBool, bool, INPUTS, OUTPUTS) \ | |||
| OUTPUT_COMPUTE_CASE(TYPE, kObjectTypeString, std::string, INPUTS, OUTPUTS) | |||
| #define OUTPUT_COMPUTE_CASE(TYPE, ODTYPE, OTYPE, INPUTS, OUTPUTS) \ | |||
| case (ODTYPE): { \ | |||
| LaunchKernel<TYPE, OTYPE>(INPUTS, OUTPUTS); \ | |||
| break; \ | |||
| } | |||
| } // namespace | |||
| void OneHotCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (output_shape.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ | |||
| << "', the dimension of output should be greater than or equal to 2, but got " | |||
| << output_shape.size() << "."; | |||
| } | |||
| int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS); | |||
| if (axis != -1 && LongToSize(axis) >= output_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ | |||
| << "', the 'axis' should be -1, or an int which is less than the dimension of output, but got " | |||
| << axis << ", got the dimension of output " << output_shape.size(); | |||
| } | |||
| if (axis == -1) { | |||
| axis_ = output_shape.size() - 1; | |||
| } else { | |||
| @@ -56,12 +89,24 @@ bool OneHotCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, c | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kOneHotInputsNum, kernel_name_); | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOneHotOutputsNum, kernel_name_); | |||
| const auto *indices = reinterpret_cast<int *>(inputs[0]->addr); | |||
| auto on_value = reinterpret_cast<float *>(inputs[1]->addr)[0]; | |||
| auto off_value = reinterpret_cast<float *>(inputs[2]->addr)[0]; | |||
| auto *output = reinterpret_cast<float *>(outputs[0]->addr); | |||
| size_t elem_num = inputs[0]->size / sizeof(int); | |||
| switch (input_dtype_) { | |||
| INPUT_COMPUTE_CASE(kNumberTypeUInt8, uint8_t, output_dtype_, inputs, outputs); | |||
| INPUT_COMPUTE_CASE(kNumberTypeInt32, int32_t, output_dtype_, inputs, outputs); | |||
| INPUT_COMPUTE_CASE(kNumberTypeInt64, int64_t, output_dtype_, inputs, outputs); | |||
| default: | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input 'x' " | |||
| << TypeIdToType(input_dtype_)->ToString() << " not support."; | |||
| } | |||
| return true; | |||
| } | |||
| template <typename ID, typename OD> | |||
| void OneHotCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| const auto *indices = reinterpret_cast<ID *>(inputs[0]->addr); | |||
| auto on_value = reinterpret_cast<OD *>(inputs[1]->addr)[0]; | |||
| auto off_value = reinterpret_cast<OD *>(inputs[2]->addr)[0]; | |||
| auto *output = reinterpret_cast<OD *>(outputs[0]->addr); | |||
| size_t elem_num = inputs[0]->size / sizeof(ID); | |||
| auto task = [this, &indices, &on_value, &off_value, &output](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| size_t stride_num = i / stride_; | |||
| @@ -78,8 +123,6 @@ bool OneHotCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, c | |||
| } | |||
| }; | |||
| ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -35,12 +35,375 @@ class OneHotCpuKernelMod : public NativeCpuKernelMod { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| template <typename ID, typename OD> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| TypeId input_dtype_{kTypeUnknown}; | |||
| TypeId output_dtype_{kTypeUnknown}; | |||
| size_t depth_{0}; | |||
| size_t stride_{0}; | |||
| size_t axis_{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(OneHot, KernelAttr(), OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddOutputAttr(kNumberTypeUInt8), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt16) | |||
| .AddInputAttr(kNumberTypeUInt16) | |||
| .AddOutputAttr(kNumberTypeUInt16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt32) | |||
| .AddInputAttr(kNumberTypeUInt32) | |||
| .AddOutputAttr(kNumberTypeUInt32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt64) | |||
| .AddInputAttr(kNumberTypeUInt64) | |||
| .AddOutputAttr(kNumberTypeUInt64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeInt8) | |||
| .AddInputAttr(kNumberTypeInt8) | |||
| .AddOutputAttr(kNumberTypeInt8), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeInt16) | |||
| .AddInputAttr(kNumberTypeInt16) | |||
| .AddOutputAttr(kNumberTypeInt16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddOutputAttr(kNumberTypeComplex64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddOutputAttr(kNumberTypeComplex128), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kObjectTypeString) | |||
| .AddInputAttr(kObjectTypeString) | |||
| .AddOutputAttr(kObjectTypeString), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddOutputAttr(kNumberTypeUInt8), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeUInt16) | |||
| .AddInputAttr(kNumberTypeUInt16) | |||
| .AddOutputAttr(kNumberTypeUInt16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeUInt32) | |||
| .AddInputAttr(kNumberTypeUInt32) | |||
| .AddOutputAttr(kNumberTypeUInt32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeUInt64) | |||
| .AddInputAttr(kNumberTypeUInt64) | |||
| .AddOutputAttr(kNumberTypeUInt64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt8) | |||
| .AddInputAttr(kNumberTypeInt8) | |||
| .AddOutputAttr(kNumberTypeInt8), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt16) | |||
| .AddInputAttr(kNumberTypeInt16) | |||
| .AddOutputAttr(kNumberTypeInt16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddOutputAttr(kNumberTypeComplex64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddOutputAttr(kNumberTypeComplex128), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kObjectTypeString) | |||
| .AddInputAttr(kObjectTypeString) | |||
| .AddOutputAttr(kObjectTypeString), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddInputAttr(kNumberTypeUInt8) | |||
| .AddOutputAttr(kNumberTypeUInt8), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeUInt16) | |||
| .AddInputAttr(kNumberTypeUInt16) | |||
| .AddOutputAttr(kNumberTypeUInt16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeUInt32) | |||
| .AddInputAttr(kNumberTypeUInt32) | |||
| .AddOutputAttr(kNumberTypeUInt32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeUInt64) | |||
| .AddInputAttr(kNumberTypeUInt64) | |||
| .AddOutputAttr(kNumberTypeUInt64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt8) | |||
| .AddInputAttr(kNumberTypeInt8) | |||
| .AddOutputAttr(kNumberTypeInt8), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt16) | |||
| .AddInputAttr(kNumberTypeInt16) | |||
| .AddOutputAttr(kNumberTypeInt16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddOutputAttr(kNumberTypeComplex64), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddOutputAttr(kNumberTypeComplex128), | |||
| OneHotCpuKernelMod); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kObjectTypeString) | |||
| .AddInputAttr(kObjectTypeString) | |||
| .AddOutputAttr(kObjectTypeString), | |||
| OneHotCpuKernelMod); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -74,13 +74,17 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve | |||
| TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| auto op_name = prim->name(); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex0]->BuildType(), {kInt32, kInt64}, | |||
| op_name); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex0]->BuildType(), | |||
| {kUInt8, kInt32, kInt64}, op_name); | |||
| (void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[kInputIndex1]->BuildType(), | |||
| {kInt8, kInt16, kInt32, kInt64}, op_name); | |||
| std::map<std::string, TypePtr> args = {{"on_value", input_args[kInputIndex2]->BuildType()}, | |||
| {"off_dtype", input_args[kInputIndex3]->BuildType()}}; | |||
| return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name); | |||
| return CheckAndConvertUtils::CheckTensorTypeSame( | |||
| args, | |||
| {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8, kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, | |||
| kFloat64, kComplex64, kComplex128}, | |||
| op_name); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -115,3 +115,4 @@ from .environ_destroy_all import _environ_destroy_all_aicpu | |||
| from .cross import _cross_aicpu | |||
| from .cummax import _cummax_aicpu | |||
| from .floor_div import _floor_div_aicpu | |||
| from .one_hot import _one_hot_aicpu | |||
| @@ -0,0 +1,116 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """OneHot op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| one_hot_op_info = AiCPURegOp("OneHot") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "indices", "required") \ | |||
| .input(1, "depth", "required") \ | |||
| .input(2, "on_value", "required") \ | |||
| .input(3, "off_value", "required") \ | |||
| .output(0, "output", "required") \ | |||
| .attr("axis", "int") \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, | |||
| DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U16_Default, | |||
| DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U32_Default, | |||
| DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U64_Default, | |||
| DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I8_Default, | |||
| DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I16_Default, | |||
| DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I32_Default, | |||
| DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, | |||
| DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.F64_Default, | |||
| DataType.F64_Default, DataType.F64_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.C64_Default, | |||
| DataType.C64_Default, DataType.C64_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.C128_Default, | |||
| DataType.C128_Default, DataType.C128_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.BOOL_Default, | |||
| DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U8_Default, | |||
| DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U16_Default, | |||
| DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U32_Default, | |||
| DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default, | |||
| DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I8_Default, | |||
| DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I16_Default, | |||
| DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, | |||
| DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, | |||
| DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F64_Default, | |||
| DataType.F64_Default, DataType.F64_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.C64_Default, | |||
| DataType.C64_Default, DataType.C64_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.C128_Default, | |||
| DataType.C128_Default, DataType.C128_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.BOOL_Default, | |||
| DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U8_Default, | |||
| DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U16_Default, | |||
| DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U32_Default, | |||
| DataType.U32_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default, | |||
| DataType.U64_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I8_Default, | |||
| DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I16_Default, | |||
| DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, | |||
| DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, | |||
| DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F64_Default, | |||
| DataType.F64_Default, DataType.F64_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.C64_Default, | |||
| DataType.C64_Default, DataType.C64_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.C128_Default, | |||
| DataType.C128_Default, DataType.C128_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default, | |||
| DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(one_hot_op_info) | |||
| def _one_hot_aicpu(): | |||
| """OneHot aicpu register""" | |||
| return | |||
| @@ -3283,10 +3283,11 @@ class OneHot(Primitive): | |||
| Inputs: | |||
| - **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`. | |||
| Data type must be int32 or int64. | |||
| Data type must be uint8, int32 or int64. | |||
| - **depth** (int) - A scalar defining the depth of the one-hot dimension. | |||
| - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`. | |||
| With data type of float16 or float32. | |||
| Support uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, | |||
| bool, complex64, complex128. | |||
| - **off_value** (Tensor) - A value to fill in output when `indices[j] != i`. | |||
| Has the same data type as `on_value`. | |||
| @@ -3295,7 +3296,7 @@ class OneHot(Primitive): | |||
| Raises: | |||
| TypeError: If `axis` or `depth` is not an int. | |||
| TypeError: If dtype of `indices` is neither int32 nor int64. | |||
| TypeError: If dtype of `indices` is not uint8, int32 or int64. | |||
| TypeError: If `indices`, `on_value` or `off_value` is not a Tensor. | |||
| ValueError: If `axis` is not in range [-1, len(indices_shape)]. | |||
| ValueError: If `depth` is less than 0. | |||