| @@ -0,0 +1,143 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "device/cpu/kernel_select_cpu.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "kernel/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace cpu { | |||
| using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; | |||
| using mindspore::kernel::KernelBuildInfo; | |||
| namespace { | |||
| bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { | |||
| auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes, | |||
| const CNodePtr kernel_node) { | |||
| for (auto &input_index : input_not_cnode_indexes) { | |||
| auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| std::vector<TypeId> output_types; | |||
| output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| builder->SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder->SetOutputsDeviceType(output_types); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); | |||
| } | |||
| } | |||
| void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats, | |||
| std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| TypeId dtype = kTypeUnknown; | |||
| if (IsInputNotCNode(kernel_node, input_index)) { | |||
| input_no_cnode_indexes->emplace_back(input_index); | |||
| } else { | |||
| dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); | |||
| } | |||
| input_formats->emplace_back(kOpFormat_DEFAULT); | |||
| input_types->emplace_back(dtype); | |||
| } | |||
| } | |||
| void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, | |||
| std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (kernel_attr.GetOutputSize() != output_num) { | |||
| MS_LOG(EXCEPTION) << "Output num is not equal!"; | |||
| } | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); | |||
| auto dtype = kernel_attr.GetOutputAttr(output_index).first; | |||
| output_types->emplace_back(dtype); | |||
| } | |||
| } | |||
| bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<std::string> &input_formats, | |||
| const std::vector<TypeId> &input_types, | |||
| const std::vector<size_t> &input_not_cnode_indexes) { | |||
| if (kernel_attr.GetInputSize() != input_types.size()) { | |||
| MS_LOG(ERROR) << "Output num is not equal!"; | |||
| return false; | |||
| } | |||
| auto input_num = input_types.size(); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), | |||
| [i](size_t index) { return index == i; }); | |||
| if (is_not_cnode_idx) { | |||
| continue; | |||
| } | |||
| if (kernel_attr.GetInputAttr(i).first != input_types[i]) { | |||
| MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i]; | |||
| return false; | |||
| } | |||
| if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { | |||
| MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i]; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| std::vector<std::string> input_formats; | |||
| std::vector<TypeId> input_types; | |||
| std::vector<size_t> input_not_cnode_indexes; | |||
| std::vector<std::string> output_formats; | |||
| std::vector<TypeId> output_types; | |||
| GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); | |||
| auto kernel_attrs = | |||
| kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); | |||
| for (auto &kernel_attr : kernel_attrs) { | |||
| if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { | |||
| GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); | |||
| UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); | |||
| for (auto &input_index : input_not_cnode_indexes) { | |||
| input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| builder->SetInputsFormat(input_formats); | |||
| builder->SetInputsDeviceType(input_types); | |||
| builder->SetOutputsFormat(output_formats); | |||
| builder->SetOutputsDeviceType(output_types); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); | |||
| } | |||
| } // namespace cpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_ | |||
| #define MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_ | |||
| #include <utility> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace cpu { | |||
| void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | |||
| class KernelAttr { | |||
| public: | |||
| using DataType = std::pair<TypeId, std::string>; | |||
| KernelAttr() = default; | |||
| ~KernelAttr() = default; | |||
| KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { | |||
| input_type_.emplace_back(ms_type, format); | |||
| return *this; | |||
| } | |||
| KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { | |||
| output_type_.emplace_back(ms_type, format); | |||
| return *this; | |||
| } | |||
| const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } | |||
| const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||
| size_t GetInputSize() const { return input_type_.size(); } | |||
| size_t GetOutputSize() const { return output_type_.size(); } | |||
| private: | |||
| std::vector<DataType> input_type_; | |||
| std::vector<DataType> output_type_; | |||
| }; | |||
| } // namespace cpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_ | |||
| @@ -33,7 +33,15 @@ class ApplyMomentumCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(ApplyMomentum, ApplyMomentumCPUKernel); | |||
| MS_REG_CPU_KERNEL(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| ApplyMomentumCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -37,7 +37,8 @@ class ArgmaxCPUKernel : public CPUKernel { | |||
| size_t batch_size_{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Argmax, ArgmaxCPUKernel); | |||
| MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArgmaxCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -37,7 +37,10 @@ class BiasAddCPUKernel : public CPUKernel { | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> bias_shape_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(BiasAdd, BiasAddCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| BiasAdd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| BiasAddCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ | |||
| @@ -36,7 +36,8 @@ class BiasAddGradCPUKernel : public CPUKernel { | |||
| private: | |||
| std::vector<size_t> input_shape_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(BiasAddGrad, BiasAddGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| BiasAddGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ | |||
| @@ -20,29 +20,79 @@ | |||
| #include <iostream> | |||
| #include <string> | |||
| #include "device/kernel_info.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| CPUKernelFactory &CPUKernelFactory::Get() { | |||
| CPUKernelFactory &CPUKernelFactory::GetInstance() { | |||
| static CPUKernelFactory instance; | |||
| return instance; | |||
| } | |||
| void CPUKernelFactory::Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) { | |||
| if (kernel_creators_.find(kernel_name) == kernel_creators_.end()) { | |||
| (void)kernel_creators_.emplace(kernel_name, kernel_creator); | |||
| void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, | |||
| CPUKernelCreator &&kernel_creator) { | |||
| (void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator); | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; | |||
| MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; | |||
| #endif | |||
| } | |||
| } | |||
| std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name) { | |||
| auto iter = kernel_creators_.find(kernel_name); | |||
| if (iter != kernel_creators_.end()) { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| return (iter->second)(); | |||
| std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { | |||
| auto kernel_info = apply_kernel->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_build_Info); | |||
| std::pair<bool, size_t> ret_pair = CPUKernelAttrCheck(kernel_name, kernel_build_Info); | |||
| if (ret_pair.first) { | |||
| return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second(); | |||
| } | |||
| return nullptr; | |||
| } | |||
| std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name, | |||
| const KernelBuildInfo *kernel_info) { | |||
| auto iter = name_to_attr_creator_.find(kernel_name); | |||
| if (iter == name_to_attr_creator_.end()) { | |||
| MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!"; | |||
| return std::make_pair(false, 0); | |||
| } | |||
| auto creators = iter->second; | |||
| for (size_t index = 0; index < creators.size(); ++index) { | |||
| auto attr_creator = creators[index]; | |||
| for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) { | |||
| if (kernel_info->GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) { | |||
| MS_LOG(WARNING) << "cpu kernel attr check failed. input index: " << i << "."; | |||
| MS_LOG(WARNING) << "kernel info type:" << kernel_info->GetInputDeviceType(i) << ", " | |||
| << "register type:" << attr_creator.first.GetInputAttr(i).first; | |||
| return std::make_pair(false, 0); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) { | |||
| if (kernel_info->GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) { | |||
| MS_LOG(WARNING) << "cpu kernel attr check failed. output index: " << i << "."; | |||
| MS_LOG(WARNING) << "kernel info type:" << kernel_info->GetOutputDeviceType(i) << ", " | |||
| << "register type:" << attr_creator.first.GetOutputAttr(i).first; | |||
| return std::make_pair(false, 0); | |||
| } | |||
| } | |||
| return std::make_pair(true, index); | |||
| } | |||
| return std::make_pair(false, 0); | |||
| } | |||
| std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) { | |||
| std::vector<KernelAttr> result; | |||
| auto iter = name_to_attr_creator_.find(kernel_name); | |||
| if (iter == name_to_attr_creator_.end()) { | |||
| MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; | |||
| return result; | |||
| } | |||
| auto creators = iter->second; | |||
| for (size_t index = 0; index < creators.size(); ++index) { | |||
| auto attr_creator = creators[index]; | |||
| result.push_back(attr_creator.first); | |||
| } | |||
| return result; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -21,35 +21,54 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "common/utils.h" | |||
| #include "kernel/cpu/cpu_kernel.h" | |||
| #include "device/cpu/kernel_select_cpu.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| using mindspore::device::cpu::KernelAttr; | |||
| using CPUKernelCreator = std::function<std::shared_ptr<CPUKernel>()>; | |||
| class CPUKernelFactory { | |||
| public: | |||
| static CPUKernelFactory &Get(); | |||
| void Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator); | |||
| static CPUKernelFactory &GetInstance(); | |||
| void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); | |||
| std::shared_ptr<CPUKernel> Create(const std::string &kernel_name); | |||
| std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel); | |||
| std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name); | |||
| private: | |||
| CPUKernelFactory() = default; | |||
| ~CPUKernelFactory() = default; | |||
| DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) | |||
| std::map<std::string, CPUKernelCreator> kernel_creators_; | |||
| std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info); | |||
| std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_; | |||
| }; | |||
| class CPUKernelRegistrar { | |||
| public: | |||
| CPUKernelRegistrar(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) { | |||
| CPUKernelFactory::Get().Register(kernel_name, std::move(kernel_creator)); | |||
| CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) { | |||
| CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator)); | |||
| } | |||
| ~CPUKernelRegistrar() = default; | |||
| }; | |||
| #define MS_REG_CPU_KERNEL(KERNEL_NAME, KERNEL_CLASS) \ | |||
| static const CPUKernelRegistrar g_cpu_kernel_##KERNEL_NAME##_reg(#KERNEL_NAME, \ | |||
| []() { return std::make_shared<KERNEL_CLASS>(); }); | |||
| #define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) \ | |||
| static_assert(std::is_base_of<CPUKernel, OPCLASS>::value, " must be base of CPUKernel"); \ | |||
| static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_reg(#OPNAME, ATTR, \ | |||
| []() { return std::make_shared<OPCLASS>(); }); | |||
| #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \ | |||
| static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \ | |||
| static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \ | |||
| []() { return std::make_shared<OPCLASS<T>>(); }); | |||
| #define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ | |||
| static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \ | |||
| static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \ | |||
| #OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); }); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,10 @@ class EqualCountCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(EqualCount, EqualCountCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| EqualCount, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| EqualCountCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,10 @@ class Conv2dCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Conv2D, Conv2dCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Conv2D, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Conv2dCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,10 @@ class Conv2dGradFilterCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Conv2DBackpropFilter, Conv2dGradFilterCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Conv2DBackpropFilter, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Conv2dGradFilterCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,10 @@ class Conv2dGradInputCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Conv2DBackpropInput, Conv2dGradInputCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Conv2DBackpropInput, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Conv2dGradInputCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -40,7 +40,10 @@ class MatMulCPUKernel : public MKLCPUKernel { | |||
| dnnl_dim_t dim_k_{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(MatMul, MatMulCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| MatMul, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| MatMulCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,9 @@ class MulCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Mul, MulCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| MulCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,8 @@ class PoolingCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(MaxPool, PoolingCPUKernel); | |||
| MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| PoolingCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -43,7 +43,13 @@ class PoolingGradCPUKernel : public MKLCPUKernel { | |||
| std::vector<size_t> dst_shape_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(MaxPoolGrad, PoolingGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(MaxPoolGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| PoolingGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,7 @@ class ReluCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(ReLU, ReluCPUKernel); | |||
| MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,10 @@ class ReluGradCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(ReluGrad, ReluGradCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| ReluGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ReluGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,8 @@ class SoftmaxCPUKernel : public MKLCPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Softmax, SoftmaxCPUKernel); | |||
| MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SoftmaxCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -43,7 +43,10 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { | |||
| size_t batch_size_{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(SparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogitsCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| SparseSoftmaxCrossEntropyWithLogits, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| SparseSoftmaxCrossEntropyWithLogitsCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -38,7 +38,13 @@ class OneHotCPUKernel : public CPUKernel { | |||
| size_t axis_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(OneHot, OneHotCPUKernel); | |||
| MS_REG_CPU_KERNEL(OneHot, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| OneHotCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,9 +33,12 @@ class ReshapeCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Reshape, ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Flatten, ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(ExpandDims, ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ReshapeCPUKernel); | |||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ReshapeCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ | |||
| #include "device/kernel_runtime.h" | |||
| #include "predict/predict.h" | |||
| #include "kernel/cpu/cpu_kernel_factory.h" | |||
| #include "device/cpu/kernel_select_cpu.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| @@ -63,43 +64,7 @@ void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { | |||
| auto &kernel_nodes = kernel_graph->execution_order(); | |||
| for (const auto &kernel_node : kernel_nodes) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| auto input_kernel_node = kernel_node->input(input_index + 1); | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (!input_kernel_node->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| std::vector<std::string> output_formats = {kOpFormat_DEFAULT}; | |||
| builder->SetOutputsFormat(output_formats); | |||
| std::vector<TypeId> output_types{kNumberTypeFloat32}; | |||
| builder->SetOutputsDeviceType(output_types); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); | |||
| } | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| std::vector<std::string> input_formats; | |||
| std::vector<TypeId> input_types; | |||
| std::vector<std::string> output_formats; | |||
| std::vector<TypeId> output_types; | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| input_formats.emplace_back(kOpFormat_DEFAULT); | |||
| input_types.emplace_back(kNumberTypeFloat32); | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| output_formats.emplace_back(kOpFormat_DEFAULT); | |||
| output_types.emplace_back(kNumberTypeFloat32); | |||
| } | |||
| builder->SetInputsFormat(input_formats); | |||
| builder->SetInputsDeviceType(input_types); | |||
| builder->SetOutputsFormat(output_formats); | |||
| builder->SetOutputsDeviceType(output_types); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); | |||
| device::cpu::SetKernelInfo(kernel_node); | |||
| } | |||
| } | |||
| @@ -110,7 +75,8 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "]."; | |||
| std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::Get().Create(kernel_name); | |||
| std::shared_ptr<kernel::CPUKernel> cpu_kernel = | |||
| kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node); | |||
| if (cpu_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support."; | |||
| } | |||