@@ -45,64 +45,6 @@ enum MatchCountPriority : int {
const size_t kMaxCount = 0xffffffff;
const int kUnSupportMixedDataTypeIndex = -1;
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                             kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
                                             kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  // if format is default, it remarkes support all format
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(EXCEPTION) << "got the unknown format " << format;
  }
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  // if shape size is 0, the shape will be a scalar
  if (shape.empty()) {
    return true;
  }
  if (shape.size() > kShapeSupportFormatMap.size()) {
    return false;
  }
  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
    return true;
  }
  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
}
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
    if (!IsShapeMatchFormat(shape, format)) {
      return false;
    }
    for (auto shape_value : shape) {
      if (shape_value == 0) {
        MS_LOG(EXCEPTION) << "dimension size of the tensor shape should be a positive integer, but got " << shape_value;
      }
    }
    return true;
  };
  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
  }
  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
      return false;
    }
  }
  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
      return false;
    }
  }
  return true;
}
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Check input data type
@@ -459,6 +401,29 @@ int PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
  // raise precision
  int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                    kernel_support_datatype, kernel_match_datatype_idx);
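  // Among the kernels that can raise precision, prefer the one whose datatype list
  // matches the node's datatypes in the most positions.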
  if (selected_index != -1) {
    int max_match = 0;
    auto iter = kernel_match_datatype_idx->begin();
    while (iter != kernel_match_datatype_idx->end()) {
      auto kernel_datatypes = kernel_support_datatype.find(iter->first);
      if (kernel_datatypes == kernel_support_datatype.end()) {
        MS_LOG(EXCEPTION) << "Can not find kernel index " << iter->first << "'s datatype.";
      }
      if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) {
        MS_LOG(EXCEPTION) << "Kernel datatype size is smaller than node datatype size!";
      }
      // Reset the counter per candidate kernel so matches do not accumulate across kernels.
      int match_count = 0;
      for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) {
        if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) {
          ++match_count;
        }
      }
      if (match_count > max_match) {
        max_match = match_count;
        selected_index = SizeToInt(iter->first);
      }
      ++iter;
    }
  }
  if (selected_index == -1 && context_ptr->enable_reduce_precision()) {
    selected_index =
      RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
@@ -507,9 +472,6 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
  kernel::KernelQuery(kernel_node, &kernel_info_list);
  std::vector<int> most_match_counts = {-1, -1, -1, -1};
  int selected_index = -1;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool auto_mixed_precision = context_ptr->auto_mixed_precision_flag();
  std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx;
  std::unordered_map<size_t, std::vector<TypeId>> kernel_support_datatype;
  std::vector<int> node_mix_precision_datatype_index;
@@ -517,16 +479,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
    auto kernel_build_info = *(kernel_info_list[info_index]);
    if (!IsValidKernelInfo(kernel_node, kernel_build_info)) {
      continue;
    }
    std::vector<int> support_indexes;
    std::vector<TypeId> support_datatypes;
    AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype,
                             &support_datatypes, &node_mix_precision_datatype_index);
    kernel_match_datatype_idx[info_index] = support_indexes;
    kernel_support_datatype[info_index] = support_datatypes;
    if (!auto_mixed_precision && !MatchInferOutputDataType(kernel_node, kernel_build_info)) {
    if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) {
      continue;
    }
    std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
@@ -19,6 +19,7 @@
#include <unordered_map>
#include <memory>
#include <map>
#include <set>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
@@ -510,6 +511,64 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
  return true;
}
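// Checks whether a tensor shape is representable in the given device format; the default
// format and scalar shapes are accepted unconditionally.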
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                               kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
                                               kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
  // The default format means every format is supported.
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
  }
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  // An empty shape means the tensor is a scalar, which matches any format.
  if (shape.empty()) {
    return true;
  }
  if (shape.size() > kShapeSupportFormatMap.size()) {
    return false;
  }
  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
    return true;
  }
  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
}
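// Validates that every input and output of the kernel build info pairs a legal shape/format
// combination; for Cast, additionally requires the device types to match the inferred types.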
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
    if (!IsShapeMatchFormat(shape, format)) {
      return false;
    }
    for (auto shape_value : shape) {
      if (shape_value == 0) {
        MS_LOG(EXCEPTION) << "Dimension size of the tensor shape should be a positive integer, but got "
                          << shape_value;
      }
    }
    return true;
  };
  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
      return false;
    }
  }
  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
      return false;
    }
  }
  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
  }
  return true;
}
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
@@ -534,7 +593,7 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
    if (context_ptr->execution_mode() == kPynativeMode) {
      kernel_info_list->push_back(parse_info);
    } else {
      if (CheckSupported(kernel_node, parse_info)) {
      if (IsValidKernelInfo(kernel_node, *(parse_info)) && CheckSupported(kernel_node, parse_info)) {
        kernel_info_list->push_back(parse_info);
      } else {
        MS_LOG(INFO) << "CheckSupported failed for TBE op " << op_name << " kernel info.";
@@ -37,6 +37,7 @@
#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
#include "pre_activate/ascend/ir_fusion/transdata_split.h"
#include "pre_activate/ascend/ir_fission/topk_split.h"
@@ -243,6 +244,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto other_pm = std::make_shared<PassManager>("other_pm");
  other_pm->AddPass(std::make_shared<AllReduceFusion>());
  other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
  other_pm->AddPass(std::make_shared<BufferFusion>());
  other_pm->AddPass(std::make_shared<GetitemTuple>());
  other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
@@ -0,0 +1,120 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "pre_activate/common/helper.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
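// Walks backward from `node` through a chain of Cast/Transpose/Reshape/TransData ops,
// recording the chain in `trans_road`, and returns the Parameter or ValueNode feeding the
// chain (or nullptr if the chain is shared or ends elsewhere).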
const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
                                std::vector<CNodePtr> *trans_road) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "Node is nullptr.";
    return nullptr;
  }
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    auto op_name = AnfAlgo::GetCNodeName(cnode);
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      return nullptr;
    }
    if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
        op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
      auto users = manager->node_users()[node];
      // A trans op shared by multiple users cannot be fused away, except at the head of the road.
      if (users.size() > 1 && !first_flag) {
        return nullptr;
      }
      trans_road->push_back(cnode);
      first_flag = false;
      auto next_node = AnfAlgo::GetInputNode(cnode, 0);
      if (next_node->isa<Parameter>() || next_node->isa<ValueNode>()) {
        return next_node;
      }
      return ParamTransRoad(func_graph, next_node, first_flag, trans_road);
    }
  } else if (node->isa<Parameter>() || node->isa<ValueNode>()) {
    return node;
  }
  return nullptr;
}
bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Func graph is nullptr";
    return false;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    return false;
  }
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  bool changed = false;
  for (auto node : node_list) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() ||
        node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) {
      MS_LOG(DEBUG) << "Skip trans op";
      continue;
    }
    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
      std::vector<CNodePtr> trans_road;
      bool first_flag = true;
      auto final_node =
        ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);
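      // Match the road Parameter -> TransData -> Cast -> TransData (trans_road[2..0]) and
      // rebuild the Cast to run directly in the device format, so both TransData ops become
      // redundant when the parameter already carries that format.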
      if (final_node != nullptr && trans_road.size() == 3 &&
          AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName &&
          AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() &&
          AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) {
        auto cur_transop = trans_road[0];
        auto format = AnfAlgo::GetOutputFormat(cur_transop, 0);
        auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0);
        auto param_format = AnfAlgo::GetOutputFormat(final_node, 0);
        auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);

        auto cast = trans_road[1];
        auto cast_format = AnfAlgo::GetOutputFormat(cast, 0);
        auto cast_build_info = cast->kernel_info()->select_kernel_build_info();
        kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
        builder.SetOutputsFormat({format});
        builder.SetInputsFormat({format});
        builder.SetInputsDeviceType({param_dtype});
        builder.SetOutputsDeviceType({dtype});
        builder.SetKernelType(cast_build_info->kernel_type());
        builder.SetFusionType(cast_build_info->fusion_type());
        builder.SetProcessor(cast_build_info->processor());
        AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());

        if (param_format == format && param_dtype != dtype) {
          manager->Replace(trans_road[2], final_node);
          manager->Replace(cur_transop, cast);
        }
        changed = true;
      }
    }
  }
  return changed;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "pre_activate/common/pass.h"

namespace mindspore {
namespace opt {
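// Fuses Parameter -> TransData -> Cast -> TransData chains into a single device-format Cast.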
class ParameterTransOpFusion : public Pass {
 public:
  explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {}
  ~ParameterTransOpFusion() override = default;
  bool Run(const FuncGraphPtr &graph) override;

 private:
  size_t groups_ = 1;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
@@ -44,6 +44,12 @@ cast_op_info = TBERegOp("Cast") \
    .dtype_format(DataType.F16_Default, DataType.U8_Default) \
    .dtype_format(DataType.F16_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default) \
    .get_op_info()
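# The new 5HD/FracZ/FracNZ pairs register fp16 <-> fp32 casts in device formats, which the
# parameter_and_transop_fusion pass relies on when it rebuilds Cast to absorb TransData ops.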