Merge pull request !1531 from sunsuodong/multi_input (tags/v0.5.0-beta)
| @@ -109,6 +109,21 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector< | |||
| } | |||
| return true; | |||
| } | |||
| void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { | |||
| MS_EXCEPTION_IF_NULL(kernel_attr); | |||
| TypeId input_dtype = kernel_attr->GetInputAttr(0).first; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t i = 1; i < input_num; ++i) { | |||
| kernel_attr->AddInputAttr(input_dtype); | |||
| } | |||
| TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t i = 1; i < output_num; ++i) { | |||
| kernel_attr->AddOutputAttr(output_dtype); | |||
| } | |||
| } | |||
| } // namespace | |||
| void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| @@ -125,12 +140,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); | |||
| for (size_t index = 0; index < kernel_attrs.size(); ++index) { | |||
| if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) { | |||
| auto kernel_attr = kernel_attrs[index]; | |||
| if (kernel_attr.GetAllSame()) { | |||
| ExpandKernelAttr(kernel_node, &kernel_attr); | |||
| } | |||
| if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { | |||
| MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; | |||
| GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types); | |||
| UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node); | |||
| GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); | |||
| UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); | |||
| for (auto &input_index : input_not_cnode_indexes) { | |||
| input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first; | |||
| input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; | |||
| } | |||
| break; | |||
| } | |||
| @@ -46,8 +46,14 @@ class KernelAttr { | |||
| return *this; | |||
| } | |||
| KernelAttr &SetAllSameAttr(bool all_same) { | |||
| all_same_ = all_same; | |||
| return *this; | |||
| } | |||
| const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } | |||
| const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||
| bool GetAllSame() const { return all_same_; } | |||
| size_t GetInputSize() const { return input_type_.size(); } | |||
| size_t GetOutputSize() const { return output_type_.size(); } | |||
| @@ -55,6 +61,7 @@ class KernelAttr { | |||
| private: | |||
| std::vector<DataType> input_type_; | |||
| std::vector<DataType> output_type_; | |||
| bool all_same_; | |||
| }; | |||
| } // namespace cpu | |||
| } // namespace device | |||
| @@ -39,16 +39,8 @@ class AddNCPUKernel : public CPUKernel { | |||
| std::vector<size_t> output_shape_; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| AddN, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| AddNCPUKernel); | |||
| MS_REG_CPU_KERNEL(AddN, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| AddNCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -41,10 +41,9 @@ class ConcatCPUKernel : public CPUKernel { | |||
| std::vector<size_t> output_shape_; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| Concat, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ConcatCPUKernel); | |||
| MS_REG_CPU_KERNEL(Concat, | |||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ConcatCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -59,28 +59,27 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string & | |||
| auto creators = iter->second; | |||
| for (size_t index = 0; index < creators.size(); ++index) { | |||
| auto attr_creator = creators[index]; | |||
| if (CPUKernelSingleAttrCheck(attr_creator, kernel_info)) { | |||
| if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) { | |||
| return std::make_pair(true, index); | |||
| } | |||
| } | |||
| return std::make_pair(false, 0); | |||
| } | |||
| bool CPUKernelFactory::CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator, | |||
| const KernelBuildInfo &kernel_info) { | |||
| bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) { | |||
| for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) { | |||
| if (kernel_info.GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) { | |||
| MS_LOG(DEBUG) << "cpu kernel attr check failed. input index: " << i << "."; | |||
| MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetInputDeviceType(i) << ", " | |||
| << "register type:" << attr_creator.first.GetInputAttr(i).first; | |||
| auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first; | |||
| if (kernel_info.GetInputDeviceType(i) != dtype) { | |||
| MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i) | |||
| << ", register type:" << dtype; | |||
| return false; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) { | |||
| if (kernel_info.GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) { | |||
| MS_LOG(DEBUG) << "cpu kernel attr check failed. output index: " << i << "."; | |||
| MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetOutputDeviceType(i) << ", " | |||
| << "register type:" << attr_creator.first.GetOutputAttr(i).first; | |||
| auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first; | |||
| if (kernel_info.GetOutputDeviceType(i) != dtype) { | |||
| MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i) | |||
| << ", register type:" << dtype; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -35,7 +35,6 @@ class CPUKernelFactory { | |||
| public: | |||
| static CPUKernelFactory &GetInstance(); | |||
| void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); | |||
| std::shared_ptr<CPUKernel> Create(const std::string &kernel_name); | |||
| std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel); | |||
| std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name); | |||
| @@ -44,8 +43,7 @@ class CPUKernelFactory { | |||
| ~CPUKernelFactory() = default; | |||
| DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) | |||
| std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); | |||
| bool CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator, | |||
| const KernelBuildInfo &kernel_info); | |||
| bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info); | |||
| std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_; | |||
| }; | |||
| @@ -71,13 +71,13 @@ def test_in2_axis1(): | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| class Concat_Axis2(nn.Cell): | |||
| class Concat_in3_Axis2(nn.Cell): | |||
| def __init__(self): | |||
| super(Concat_Axis2, self).__init__() | |||
| super(Concat_in3_Axis2, self).__init__() | |||
| self.cat = P.Concat(axis=-1) | |||
| def construct(self, x1, x2): | |||
| return self.cat((x1, x2)) | |||
| def construct(self, x1, x2, x3): | |||
| return self.cat((x1, x2, x3)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -86,10 +86,10 @@ def test_in3_axis2(): | |||
| x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32) | |||
| x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32) | |||
| x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32) | |||
| cat = Concat_Axis2() | |||
| output_ms = cat(x1, x2) | |||
| cat = Concat_in3_Axis2() | |||
| output_ms = cat(x1, x2, x3) | |||
| print("output:\n", output_ms) | |||
| output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=-1) | |||
| output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1) | |||
| error = np.ones(shape=output_np.shape) * 10e-6 | |||
| diff = output_ms.asnumpy() - output_np | |||