@@ -45,64 +45,6 @@ enum MatchCountPriority : int {
 const size_t kMaxCount = 0xffffffff;
 const int kUnSupportMixedDataTypeIndex = -1;
-const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
-                                             kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
-                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
-                                             kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
-bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
-  // if format is default, it remarkes support all format
-  if (kOpFormatList.find(format) == kOpFormatList.end()) {
-    MS_LOG(EXCEPTION) << "got the unknown format " << format;
-  }
-  if (format == kOpFormat_DEFAULT) {
-    return true;
-  }
-  // if shape size is 0, the shape will be a scalar
-  if (shape.empty()) {
-    return true;
-  }
-  if (shape.size() > kShapeSupportFormatMap.size()) {
-    return false;
-  }
-  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
-    return true;
-  }
-  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
-}
-bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
-    if (!IsShapeMatchFormat(shape, format)) {
-      return false;
-    }
-    for (auto shape_value : shape) {
-      if (shape_value == 0) {
-        MS_LOG(EXCEPTION) << "dimension size of the tensor shape should be a positive integer, but got " << shape_value;
-      }
-    }
-    return true;
-  };
-  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
-    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
-           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
-  }
-  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
-    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
-      return false;
-    }
-  }
-  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
-    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
-      return false;
-    }
-  }
-  return true;
-}
 bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   // Check input data type
@@ -459,6 +401,29 @@ int PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
   // raise precision
   int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                     kernel_support_datatype, kernel_match_datatype_idx);
+  if (selected_index != -1) {
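+    // Among the kernels left after raising precision, prefer the one whose
+    // supported data types match the node's inferred data types at the most
+    // positions; with a strict '>' comparison the first best match wins ties.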
+    int max_match = 0;
+    auto iter = kernel_match_datatype_idx->begin();
+    while (iter != kernel_match_datatype_idx->end()) {
+      auto kernel_datatypes = kernel_support_datatype.find(iter->first);
+      if (kernel_datatypes == kernel_support_datatype.end()) {
+        MS_LOG(EXCEPTION) << "Can not find kernel index " << iter->first << "'s data type.";
+      }
+      if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) {
+        MS_LOG(EXCEPTION) << "Kernel data type size is less than node data type size!";
+      }
+      int match_count = 0;
+      for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) {
+        if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) {
+          ++match_count;
+        }
+      }
+      if (match_count > max_match) {
+        max_match = match_count;
+        selected_index = SizeToInt(iter->first);
+      }
+      ++iter;
+    }
+  }
   if (selected_index == -1 && context_ptr->enable_reduce_precision()) {
     selected_index =
       RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
@@ -507,9 +472,6 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
   kernel::KernelQuery(kernel_node, &kernel_info_list);
   std::vector<int> most_match_counts = {-1, -1, -1, -1};
   int selected_index = -1;
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  bool auto_mixed_precision = context_ptr->auto_mixed_precision_flag();
   std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx;
   std::unordered_map<size_t, std::vector<TypeId>> kernel_support_datatype;
   std::vector<int> node_mix_precision_datatype_index;
@@ -517,16 +479,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
   for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
     std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
     auto kernel_build_info = *(kernel_info_list[info_index]);
-    if (!IsValidKernelInfo(kernel_node, kernel_build_info)) {
-      continue;
-    }
     std::vector<int> support_indexes;
     std::vector<TypeId> support_datatypes;
     AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype,
                              &support_datatypes, &node_mix_precision_datatype_index);
     kernel_match_datatype_idx[info_index] = support_indexes;
     kernel_support_datatype[info_index] = support_datatypes;
-    if (!auto_mixed_precision && !MatchInferOutputDataType(kernel_node, kernel_build_info)) {
+    if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) {
       continue;
     }
     std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include <memory>
 #include <map>
+#include <set>
 #include "session/anf_runtime_algorithm.h"
 #include "kernel/oplib/oplib.h"
@@ -510,6 +511,64 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
   return true;
 }
+bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
+  const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
+                                               kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
+                                               kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
+                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
+  // An unknown format is a hard error; the default format supports every shape.
+  if (kOpFormatList.find(format) == kOpFormatList.end()) {
+    MS_LOG(EXCEPTION) << "Got unknown format " << format;
+  }
+  if (format == kOpFormat_DEFAULT) {
+    return true;
+  }
+  // An empty shape means the tensor is a scalar, which every format can hold.
+  if (shape.empty()) {
+    return true;
+  }
+  if (shape.size() > kShapeSupportFormatMap.size()) {
+    return false;
+  }
+  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
+    return true;
+  }
+  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
+}
+bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
+    if (!IsShapeMatchFormat(shape, format)) {
+      return false;
+    }
+    for (auto shape_value : shape) {
+      if (shape_value == 0) {
+        MS_LOG(EXCEPTION) << "Dimension size of the tensor shape should be a positive integer, but got " << shape_value;
+      }
+    }
+    return true;
+  };
+  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
+    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
+      return false;
+    }
+  }
+  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
+    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
+      return false;
+    }
+  }
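+  // For Cast the kernel is valid only if its device types equal the inferred
+  // input/output types exactly; the shape/format checks above do not constrain
+  // the type conversion itself.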
+  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
+    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
+           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
+  }
+  return true;
+}
 void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(kernel_info_list);
@@ -534,7 +593,7 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
     if (context_ptr->execution_mode() == kPynativeMode) {
       kernel_info_list->push_back(parse_info);
     } else {
-      if (CheckSupported(kernel_node, parse_info)) {
+      if (IsValidKernelInfo(kernel_node, *(parse_info)) && CheckSupported(kernel_node, parse_info)) {
        kernel_info_list->push_back(parse_info);
       } else {
         MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
@@ -37,6 +37,7 @@
 #include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h"
 #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
 #include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
+#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
 #include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
 #include "pre_activate/ascend/ir_fusion/transdata_split.h"
 #include "pre_activate/ascend/ir_fission/topk_split.h"
@@ -243,6 +244,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto other_pm = std::make_shared<PassManager>("other_pm");
   other_pm->AddPass(std::make_shared<AllReduceFusion>());
+  other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
   other_pm->AddPass(std::make_shared<BufferFusion>());
   other_pm->AddPass(std::make_shared<GetitemTuple>());
   other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
+#include <memory>
+#include "session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "operator/ops.h"
+#include "device/kernel_info.h"
+#include "pre_activate/common/helper.h"
+#include "pre_activate/common/optimizer.h"
+#include "pre_activate/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
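+// Walks backward from `node` through a chain of layout/type conversion ops
+// (Cast, Transpose, Reshape, TransData), collecting them into `trans_road`,
+// until a Parameter or ValueNode is reached. Returns that source node, or
+// nullptr if the chain is broken or an intermediate op has multiple users.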
+const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
+                                std::vector<CNodePtr> *trans_road) {
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "Node is nullptr";
+    return nullptr;
+  }
+  if (node->isa<CNode>()) {
+    auto cnode = node->cast<CNodePtr>();
+    auto op_name = AnfAlgo::GetCNodeName(cnode);
+    auto manager = func_graph->manager();
+    if (manager == nullptr) {
+      return nullptr;
+    }
+    if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
+        op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
+      auto users = manager->node_users()[node];
+      if (users.size() > 1 && !first_flag) {
+        return nullptr;
+      }
+      trans_road->push_back(cnode);
+      first_flag = false;
+      auto next_node = AnfAlgo::GetInputNode(cnode, 0);
+      if (next_node->isa<Parameter>() || next_node->isa<ValueNode>()) {
+        return next_node;
+      }
+      return ParamTransRoad(func_graph, next_node, first_flag, trans_road);
+    }
+  } else if (node->isa<Parameter>() || node->isa<ValueNode>()) {
+    return node;
+  }
+  return nullptr;
+}
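+
+// Looks for the pattern Parameter -> TransData -> Cast -> TransData feeding a
+// compute node. When the parameter is already stored in the target device
+// format, both TransData ops are redundant and the chain collapses to a
+// single Cast reading the parameter directly.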
+bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Func graph is nullptr";
+    return false;
+  }
+  auto manager = func_graph->manager();
+  if (manager == nullptr) {
+    return false;
+  }
+  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
+  bool changed = false;
+  for (auto node : node_list) {
+    if (node == nullptr || !node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    auto node_name = AnfAlgo::GetCNodeName(cnode);
+    if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() ||
+        node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) {
+      MS_LOG(DEBUG) << "Skip trans op";
+      continue;
+    }
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
+      std::vector<CNodePtr> trans_road;
+      bool first_flag = true;
+      auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);
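+      // trans_road is filled from the consumer side toward the parameter:
+      // trans_road[0] is the TransData next to the current node, trans_road[1]
+      // the Cast, and trans_road[2] the TransData fed by the parameter.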
+      if (final_node != nullptr && trans_road.size() == 3 &&
+          AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName &&
+          AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() &&
+          AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) {
+        auto cur_transop = trans_road[0];
+        auto format = AnfAlgo::GetOutputFormat(cur_transop, 0);
+        auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0);
+        auto param_format = AnfAlgo::GetOutputFormat(final_node, 0);
+        auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
+        auto cast = trans_road[1];
+        auto cast_format = AnfAlgo::GetOutputFormat(cast, 0);
+        auto cast_build_info = cast->kernel_info()->select_kernel_build_info();
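+        // Re-select the Cast kernel so it consumes the parameter's device type
+        // and produces the target type directly in the TransData output format.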
+        kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+        builder.SetOutputsFormat({format});
+        builder.SetInputsFormat({format});
+        builder.SetInputsDeviceType({param_dtype});
+        builder.SetOutputsDeviceType({dtype});
+        builder.SetKernelType(cast_build_info->kernel_type());
+        builder.SetFusionType(cast_build_info->fusion_type());
+        builder.SetProcessor(cast_build_info->processor());
+        AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
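+        // The TransData pair can only be bypassed when the parameter is already
+        // stored in the target format and differs only in data type.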
+        if (param_format == format && param_dtype != dtype) {
+          manager->Replace(trans_road[2], final_node);
+          manager->Replace(cur_transop, cast);
+        }
+        changed = true;
+      }
+    }
+  }
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
+#include <vector>
+#include <string>
+#include <utility>
+#include <memory>
+#include "ir/anf.h"
+#include "pre_activate/common/pass.h"
+
+namespace mindspore {
+namespace opt {
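+// Pass that removes the redundant TransData pair around a Cast fed by a
+// Parameter; see parameter_and_transop_fusion.cc for the matched pattern.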
+class ParameterTransOpFusion : public Pass {
+ public:
+  explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {}
+  ~ParameterTransOpFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  size_t groups_ = 1;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
@@ -44,6 +44,12 @@ cast_op_info = TBERegOp("Cast") \
     .dtype_format(DataType.F16_Default, DataType.U8_Default) \
     .dtype_format(DataType.F16_Default, DataType.F32_Default) \
     .dtype_format(DataType.F16_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
     .dtype_format(DataType.F32_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.I32_Default) \
     .get_op_info()