Merge pull request !279 from changzherui/fix_fold_loc
tags/v0.6.0-beta
| @@ -1,312 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/aicpu/aicpu_kernel_build.h" | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include <fstream> | |||||
| #include <utility> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include "device/kernel_runtime.h" | |||||
| #include "kernel/aicpu/aicpu_kernel_mod.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| #include "proto/tensor.pb.h" | |||||
| #include "proto/tensor_shape.pb.h" | |||||
| #include "proto/attr.pb.h" | |||||
| #include "proto/node_def.pb.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "common/utils.h" | |||||
| #include "kernel/aicpu/aicpu_util.h" | |||||
| #include "session/kernel_graph.h" | |||||
| #include "kernel/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>; | |||||
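| // Compute the byte size of every input tensor of the node; string-type inputs use the string length as size. | |||||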
| bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num, | |||||
| std::vector<size_t> *input_size_list) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(input_size_list); | |||||
| for (size_t i = 0; i < input_num; i++) { | |||||
| std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); | |||||
| if (AnfAlgo::GetInputDeviceDataType(anf_node, i) == kObjectTypeString) { | |||||
| if (!anf_node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "anf_node is not CNode."; | |||||
| } | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (cnode->inputs().size() < (i + 1)) { | |||||
| MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1; | |||||
| return false; | |||||
| } | |||||
| auto input_node = cnode->inputs()[i + 1]; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| if (input_node->isa<ValueNode>()) { | |||||
| auto value_ptr = GetValueNode(input_node); | |||||
| auto value = GetValue<std::string>(value_ptr); | |||||
| input_size_list->push_back(value.size()); | |||||
| } | |||||
| } else { | |||||
| auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| int64_t size_i = 1; | |||||
| for (size_t j = 0; j < shape_i.size(); j++) { | |||||
| size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j])); | |||||
| } | |||||
| size_t type_byte = GetTypeByte(type_ptr); | |||||
| if (type_byte == 0) { | |||||
| return false; | |||||
| } | |||||
| size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); | |||||
| input_size_list->push_back(LongToSize(size_i)); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
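| // Fill the kernel mod's input and output size lists from the node's device shapes and data types. | |||||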
| bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod_ptr); | |||||
| std::vector<size_t> input_size_list; | |||||
| std::vector<size_t> output_size_list; | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); | |||||
| if (!SetIOIputSize(anf_node, input_num, &input_size_list)) { | |||||
| return false; | |||||
| } | |||||
| kernel_mod_ptr->SetInputSizeList(input_size_list); | |||||
| for (size_t i = 0; i < output_num; i++) { | |||||
| std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); | |||||
| TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| int64_t size_i = 1; | |||||
| for (size_t j = 0; j < shape_i.size(); j++) { | |||||
| size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j])); | |||||
| } | |||||
| size_t type_byte = GetTypeByte(type_ptr); | |||||
| if (type_byte == 0) { | |||||
| return false; | |||||
| } | |||||
| size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); | |||||
| output_size_list.push_back(LongToSize(size_i)); | |||||
| } | |||||
| kernel_mod_ptr->SetOutputSizeList(output_size_list); | |||||
| return true; | |||||
| } | |||||
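| // Convert a MindSpore attribute value into the matching AICPU proto AttrValue field according to the type string. | |||||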
| void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { | |||||
| MS_EXCEPTION_IF_NULL(node_attr); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| if (type == "int") { | |||||
| auto attr_value = GetValue<int>(value); | |||||
| (*node_attr)[attr_name].set_i(attr_value); | |||||
| } else if (type == "str") { | |||||
| auto attr_value = GetValue<std::string>(value); | |||||
| (*node_attr)[attr_name].set_s(attr_value); | |||||
| } else if (type == "bool") { | |||||
| auto attr_value = GetValue<bool>(value); | |||||
| (*node_attr)[attr_name].set_b(attr_value); | |||||
| } else if (type == "float") { | |||||
| auto attr_value = GetValue<float>(value); | |||||
| (*node_attr)[attr_name].set_f(attr_value); | |||||
| } else if (type == "listInt") { | |||||
| std::vector<int> attr_value; | |||||
| auto value_type = value->type(); | |||||
| MS_EXCEPTION_IF_NULL(value_type); | |||||
| auto value_type_str = value_type->ToString(); | |||||
| if (value_type_str == "Int32") { | |||||
| int data = GetValue<int>(value); | |||||
| attr_value.push_back(data); | |||||
| } else { | |||||
| attr_value = GetValue<std::vector<int>>(value); | |||||
| } | |||||
| mindspore::AttrValue input_shape_attr; | |||||
| mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); | |||||
| MS_EXCEPTION_IF_NULL(input_shape_attr_list); | |||||
| for (const auto shape : attr_value) { | |||||
| input_shape_attr_list->add_i(shape); | |||||
| } | |||||
| (*node_attr)[attr_name] = input_shape_attr; | |||||
| } else { | |||||
|     MS_LOG(EXCEPTION) << "type: " << type << " is not supported"; | |||||
| } | |||||
| } | |||||
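| // Copy the attributes registered in the AICPU op info from the node's primitive into the NodeDef proto. | |||||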
| void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||||
| if (op_name == kInitDataSetQueue) { | |||||
| op_name = kInitData; | |||||
| } | |||||
| if (op_name == kPrint) { | |||||
| return; | |||||
| } | |||||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); | |||||
| MS_EXCEPTION_IF_NULL(op_info_ptr); | |||||
| auto attrs_ptr = op_info_ptr->attrs_ptr(); | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); | |||||
| for (const auto &attr_ptr : attrs_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(attr_ptr); | |||||
| std::string attr_name = attr_ptr->name(); | |||||
| auto value = primitive->GetAttr(attr_name); | |||||
| if (value != nullptr) { | |||||
| if (attr_name == kQueueName || attr_name == kSharedName) { | |||||
| attr_name = kChannelName; | |||||
| } else if (attr_name == kSeed0) { | |||||
| attr_name = kSeed; | |||||
| } else if (attr_name == kSeed1) { | |||||
| attr_name = kSeed2; | |||||
| } | |||||
| std::string type = attr_ptr->type(); | |||||
| ParseAttrValue(type, attr_name, value, node_attr); | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "Set node attr end!"; | |||||
| } | |||||
| void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); | |||||
| if (input_num == 0) { | |||||
| MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; | |||||
| return; | |||||
| } | |||||
| for (size_t input_index = 0; input_index < input_num; input_index++) { | |||||
| ::mindspore::Tensor *node_inputs = proto->add_inputs(); | |||||
| MS_EXCEPTION_IF_NULL(node_inputs); | |||||
| TypeId input_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index); | |||||
| std::vector<size_t> input_shape; | |||||
| int32_t input_data_type; | |||||
| if (input_type == kObjectTypeString) { | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto input_node = cnode->inputs()[input_index + 1]; | |||||
| auto value_ptr = GetValueNode(input_node); | |||||
| auto value = GetValue<std::string>(value_ptr); | |||||
| input_shape.push_back(1); | |||||
| input_shape.push_back(value.size()); | |||||
| input_data_type = AicpuOpUtil::MsTypeToProtoType(kTypeUnknown); | |||||
| } else { | |||||
| input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); | |||||
| input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type); | |||||
| } | |||||
| mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape(); | |||||
| for (auto item : input_shape) { | |||||
| mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | |||||
| dim->set_size((::google::protobuf::int64)item); | |||||
| } | |||||
| node_inputs->set_tensor_type((mindspore::DataType)input_data_type); | |||||
| node_inputs->set_mem_device("HBM"); | |||||
| } | |||||
| } | |||||
| void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); | |||||
| if (output_num == 0) { | |||||
| MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. "; | |||||
| return; | |||||
| } | |||||
| for (size_t output_index = 0; output_index < output_num; output_index++) { | |||||
| ::mindspore::Tensor *node_outputs = proto->add_outputs(); | |||||
| MS_EXCEPTION_IF_NULL(node_outputs); | |||||
| std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); | |||||
| mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape(); | |||||
| MS_EXCEPTION_IF_NULL(tensorShape); | |||||
| for (auto item : output_shape) { | |||||
| mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | |||||
| MS_EXCEPTION_IF_NULL(dim); | |||||
| dim->set_size((::google::protobuf::int64)item); | |||||
| } | |||||
| TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index); | |||||
| int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type); | |||||
| node_outputs->set_tensor_type((mindspore::DataType)output_data_type); | |||||
| node_outputs->set_mem_device("HBM"); | |||||
| } | |||||
| } | |||||
| void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| MS_LOG(INFO) << "SetNodedefProto entry"; | |||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||||
| if (op_name == kInitDataSetQueue) { | |||||
| op_name = kInitData; | |||||
| } | |||||
| // set op name | |||||
| proto->set_op(op_name); | |||||
| // set inputs tensor | |||||
| SetNodeInputs(anf_node, proto); | |||||
| // set outputs tensor | |||||
| SetNodeOutputs(anf_node, proto); | |||||
| // set node attr | |||||
| SetNodeAttr(anf_node, proto); | |||||
| MS_LOG(INFO) << "SetNodedefProto end!"; | |||||
| } | |||||
| bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node, | |||||
| const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod_ptr); | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_LOG(INFO) << "CreateNodeDefBytes entry"; | |||||
| mindspore::NodeDef proto; | |||||
| SetNodedefProto(anf_node, &proto); | |||||
| std::string nodeDefStr; | |||||
| if (!proto.SerializeToString(&nodeDefStr)) { | |||||
| MS_LOG(ERROR) << "Serialize nodeDef to string failed."; | |||||
| return false; | |||||
| } | |||||
| kernel_mod_ptr->SetNodeDef(nodeDefStr); | |||||
| MS_LOG(INFO) << "CreateNodeDefBytes end!"; | |||||
| return true; | |||||
| } | |||||
| KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||||
| if (op_name == kInitDataSetQueue) { | |||||
| op_name = kInitData; | |||||
| } | |||||
| auto kernel_mod_ptr = std::make_shared<AicpuOpKernelMod>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod_ptr); | |||||
| kernel_mod_ptr->SetAnfNode(anf_node); | |||||
| kernel_mod_ptr->SetNodeName(op_name); | |||||
| if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { | |||||
|     MS_LOG(EXCEPTION) << "Create nodeDefBytes failed!"; | |||||
| } | |||||
| if (!SetIOSize(anf_node, kernel_mod_ptr)) { | |||||
| MS_LOG(EXCEPTION) << "Set input output size list failed."; | |||||
| } | |||||
| return kernel_mod_ptr; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -1,65 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ | |||||
| #define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ | |||||
| #include <cstdint> | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "kernel/kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| constexpr auto kInitDataSetQueue = "InitDataSetQueue"; | |||||
| constexpr auto kInitData = "InitData"; | |||||
| constexpr auto kGetNext = "GetNext"; | |||||
| constexpr auto kPrint = "Print"; | |||||
| constexpr auto kPack = "Pack"; | |||||
| constexpr auto kOutputTypes = "output_types"; | |||||
| constexpr auto kOutputShapes = "output_shapes"; | |||||
| constexpr auto kChannelName = "channel_name"; | |||||
| constexpr auto kSharedName = "shared_name"; | |||||
| constexpr auto kShapes = "shapes"; | |||||
| constexpr auto kTypes = "types"; | |||||
| constexpr auto kQueueName = "queue_name"; | |||||
| constexpr auto kSeed = "seed"; | |||||
| constexpr auto kSeed0 = "Seed0"; | |||||
| constexpr auto kSeed1 = "Seed1"; | |||||
| constexpr auto kSeed2 = "seed2"; | |||||
| constexpr auto kTopK = "TopK"; | |||||
| constexpr auto kTopKV2 = "TopKV2"; | |||||
| struct AicpuParamHead { | |||||
|   uint32_t length;  // Total length: include custom message | |||||
| uint32_t ioAddrNum; // Input and output address number | |||||
| uint32_t extInfoLength; // extInfo struct Length | |||||
| uint64_t extInfoAddr; // extInfo address | |||||
| } __attribute__((packed)); | |||||
| class AicpuOpUtil { | |||||
| public: | |||||
| static int MsTypeToProtoType(TypeId ms_type); | |||||
| private: | |||||
| // kernel id | |||||
| static uint64_t KernelId_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ | |||||
| @@ -1,622 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/oplib/oplib.h" | |||||
| #include "kernel/tbe/tbe_kernel_build.h" | |||||
| #include "nlohmann/json.hpp" | |||||
| #include "utils/context/ms_context.h" | |||||
| #include "kernel/tbe/tbe_python_funcs.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "kernel/tbe/tbe_convert_utils.h" | |||||
| #include "parallel/ops_info/ops_utils.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| constexpr auto kName = "name"; | |||||
| constexpr auto kDtype = "dtype"; | |||||
| constexpr auto kFormat = "format"; | |||||
| constexpr auto kPrefixInput = "input"; | |||||
| constexpr auto kPrefixOutput = "output"; | |||||
| constexpr char kParamTypeDynamic[] = "dynamic"; | |||||
| constexpr char kParamTypeRequre[] = "required"; | |||||
| constexpr char kParamTypeOptional[] = "optional"; | |||||
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||||
| auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list); | |||||
| tbe_selecter.TbeMetadataInfoEx(); | |||||
| } | |||||
| TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) | |||||
| : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {} | |||||
| void TbeKernelSelect::TbeMetadataInfoEx() { | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list_); | |||||
| node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); | |||||
| auto op_info_ptr = OpLib::FindOp(node_name_, kTBE); | |||||
| if (!op_info_ptr) { | |||||
|     MS_LOG(INFO) << "Warning: Can't find tbe core op info, node type: " << node_name_; | |||||
| return; | |||||
| } | |||||
|   MS_LOG(INFO) << "Start to get tbe metadata info. node type: " << node_name_ | |||||
| << ", node name: " << cnode_ptr_->fullname_with_scope(); | |||||
| OpPattern pattern = op_info_ptr->op_pattern(); | |||||
| if (pattern == kCommonPattern) { | |||||
| GetCommonPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kDynamicFormatPattern) { | |||||
| GetDynamicFormatPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kFormatAgnosticPattern) { | |||||
| GetAgnosticPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kBroadcastPattern) { | |||||
| GetBroadcastPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kReducePattern) { | |||||
| GetReducePatternKernelInfo(*op_info_ptr); | |||||
| } else { | |||||
|     MS_LOG(INFO) << "Warning: op pattern is invalid."; | |||||
| } | |||||
| // check support | |||||
| FilterInVaildKernelInfo(); | |||||
| MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; | |||||
| } | |||||
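| // Build one KernelBuildInfo for each dtype/format column declared in the op info (common pattern). | |||||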
| void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| // get dynamic inputs | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| std::vector<int> dyn_input_sizes; | |||||
| if (primitive->HasAttr(kAttrDynInputSizes)) { | |||||
| dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes)); | |||||
| } | |||||
| // get real input/output num | |||||
| size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); | |||||
| const auto inputs_info = op_info.inputs_ptr(); | |||||
| size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); | |||||
| const auto outputs_info = op_info.outputs_ptr(); | |||||
| if (inputs_info.empty() && outputs_info.empty()) { | |||||
| MS_LOG(EXCEPTION) << "op info input & output is null, please check."; | |||||
| } | |||||
| // create kernel build info from opinfo | |||||
| size_t kernel_build_info_num = | |||||
| inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size(); | |||||
| for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) { | |||||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||||
| SetTbeBuildCommonInfo(op_info, &builder); | |||||
| std::vector<std::string> inputs_format; | |||||
| std::vector<TypeId> inputs_device_type; | |||||
| std::vector<std::vector<Axis>> inputs_reshape_type; | |||||
| // input | |||||
| if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, | |||||
| &inputs_format, &inputs_device_type, &inputs_reshape_type)) { | |||||
| break; | |||||
| } | |||||
| builder.SetInputsDeviceType(inputs_device_type); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetInputReshapeType(inputs_reshape_type); | |||||
| // output | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> outputs_device_type; | |||||
| std::vector<std::vector<Axis>> outputs_reshape_type; | |||||
| if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, | |||||
| &outputs_format, &outputs_device_type, &outputs_reshape_type)) { | |||||
| break; | |||||
| } | |||||
| builder.SetOutputsDeviceType(outputs_device_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| builder.SetOutputReshapeType(outputs_reshape_type); | |||||
| kernel_info_list_->emplace_back(builder.Build()); | |||||
| } | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| // | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| if (op_info.inputs_ptr().size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; | |||||
| } | |||||
| auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); | |||||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||||
| MS_LOG(INFO) << "Got the unknown format " << format; | |||||
| format = kOpFormat_DEFAULT; | |||||
| } | |||||
| SupportFormat support_format; | |||||
| SupportFormatItem input_item; | |||||
| SupportFormatItem output_item; | |||||
| input_item.assign(op_info.inputs_ptr().size(), format); | |||||
| output_item.assign(op_info.outputs_ptr().size(), format); | |||||
| support_format.input_format.emplace_back(input_item); | |||||
| support_format.output_format.emplace_back(output_item); | |||||
| PrintSupportedFormat(support_format); | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); | |||||
| SupportFormat support_format; | |||||
| broadcast_selecter.GetShapeInfo(&support_format); | |||||
| if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; | |||||
| } | |||||
| if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; | |||||
| } | |||||
| if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; | |||||
| } | |||||
| if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; | |||||
| } | |||||
| PrintSupportedFormat(support_format); | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_); | |||||
| SupportFormat support_format; | |||||
| reduce_selecter.GetShapeInfo(&support_format); | |||||
| if (!reduce_selecter.IsReduceSupport5HD(&support_format)) { | |||||
|     MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support 5HD."; | |||||
|   } | |||||
|   if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) { | |||||
|     MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracZ."; | |||||
|   } | |||||
|   if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) { | |||||
|     MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support C1HWNCoC0."; | |||||
|   } | |||||
|   if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) { | |||||
|     MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracNZ."; | |||||
| } | |||||
| PrintSupportedFormat(support_format); | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
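| // Remove kernel build infos whose formats do not match the real shapes or that fail the TBE check_supported query. | |||||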
| void TbeKernelSelect::FilterInVaildKernelInfo() { | |||||
| if (kernel_info_list_->empty()) { | |||||
| MS_LOG(INFO) << "Warning: get kernel build info failed."; | |||||
| return; | |||||
| } | |||||
| auto kernel_build_info_iter = kernel_info_list_->begin(); | |||||
| while (kernel_build_info_iter != kernel_info_list_->end()) { | |||||
| if (!FilterInVaildShape(kernel_build_info_iter)) { | |||||
|       MS_LOG(INFO) << "Filter invalid shape, filter item info: " << (*kernel_build_info_iter)->ToString(); | |||||
| kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); | |||||
| continue; | |||||
| } | |||||
| if (!TbeCheckSupported(kernel_build_info_iter)) { | |||||
|       MS_LOG(INFO) << "Tbe check supported failed, filter item info: " << (*kernel_build_info_iter)->ToString(); | |||||
| kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); | |||||
| continue; | |||||
| } | |||||
| kernel_build_info_iter++; | |||||
| } | |||||
| } | |||||
| bool TbeKernelSelect::FilterInVaildShape( | |||||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||||
| auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); | |||||
| for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); | |||||
| auto format = kernel_build_info_inputs_format.at(i); | |||||
| if (!IsShapeMatchFormat(shape, format)) { | |||||
| MS_LOG(INFO) << "The " << i << "th input check failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); | |||||
| for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) { | |||||
| auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j); | |||||
| auto format = kernel_build_info_outputs_format.at(j); | |||||
| if (!IsShapeMatchFormat(shape, format)) { | |||||
|       MS_LOG(INFO) << "The " << j << "th output check failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | |||||
| if (format == kOpFormat_DEFAULT) { | |||||
| return true; | |||||
| } | |||||
| static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||||
|   // if format is default, it means all formats are supported | |||||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||||
| MS_LOG(EXCEPTION) << "Got the unknown format " << format; | |||||
| } | |||||
| // server not support format with C04 suffix | |||||
| if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) != | |||||
| kServerNotSupportFormat.end()) { | |||||
|     MS_LOG(INFO) << "Warning: Server does not support formats with C04 suffix."; | |||||
| return false; | |||||
| } | |||||
| // not support format: | |||||
| // 1 NDHWC with shape size != 5 | |||||
| // 2 FRAC_NZ with shape size < 2 | |||||
| // 3 !NDHWC with shape size > 4 | |||||
| if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || | |||||
| (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || | |||||
| (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { | |||||
| MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelSelect::TbeCheckSupported( | |||||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||||
| static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL, | |||||
| parallel::BATCHMATMUL, | |||||
| parallel::TOPK, | |||||
| parallel::IN_TOPK, | |||||
| parallel::PACK, | |||||
| parallel::UNSORTEF_SEGMENT_MIND, | |||||
| parallel::UNSORTEF_SEGMENT_PRODD, | |||||
| parallel::CAST}; | |||||
| auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); | |||||
| if (iter == kCheckSupportedOpType.end()) { | |||||
| return true; | |||||
| } | |||||
| MS_LOG(INFO) << "Check support start."; | |||||
| // replace kernel_info with current kernel info | |||||
| auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); | |||||
| nlohmann::json kernel_json; | |||||
| TbeKernelJsonCreator creator(CHECK_SUPPORTED); | |||||
| bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; | |||||
| } | |||||
| ret = TbePythonFuncs::CheckSupported(kernel_json); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); | |||||
| return ret; | |||||
| } | |||||
| void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info, | |||||
| mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) { | |||||
| MS_EXCEPTION_IF_NULL(builder); | |||||
| builder->SetProcessor(AICORE); | |||||
| std::string fusion_type = op_info.fusion_type(); | |||||
| if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | |||||
| builder->SetFusionType(tbe::GetFusionType(fusion_type)); | |||||
| } | |||||
| builder->SetOpPattern(op_info.op_pattern()); | |||||
| builder->SetKernelType(TBE_KERNEL); | |||||
| } | |||||
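| // Expand one dtype/format column of the op io info into per-tensor formats, device types and reshape types, | |||||
| // handling dynamic, required and optional parameters. | |||||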
| bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | |||||
| const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, | |||||
| const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats, | |||||
| std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) { | |||||
| MS_EXCEPTION_IF_NULL(formats); | |||||
| MS_EXCEPTION_IF_NULL(device_types); | |||||
| MS_EXCEPTION_IF_NULL(reshape_types); | |||||
| size_t dynamic_input_index = 0; | |||||
| size_t real_io_tensor_index = 0; | |||||
| size_t io_info_index = 0; | |||||
| size_t io_info_num = ios_info.size(); | |||||
| for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) { | |||||
| std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index]; | |||||
| auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index); | |||||
| std::string kernel_build_info_format; | |||||
| if (!io_info_item->formats().empty()) { | |||||
| kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index); | |||||
| } | |||||
| std::string io_param_type = io_info_item->param_type(); | |||||
| std::vector<Axis> reshape_type; | |||||
| StringToAxisVector(io_info_item->reshape_type(), &reshape_type); | |||||
| if (io_param_type == kParamTypeDynamic) { | |||||
| // dynamic io | |||||
| if (is_input) { | |||||
| if (dynamic_input_index >= dyn_input_sizes.size()) { | |||||
| MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index | |||||
| << ", dyn_input_sizes size: " << dyn_input_sizes.size(); | |||||
| } | |||||
| int dynamic_input_size = dyn_input_sizes[dynamic_input_index]; | |||||
| for (int i = 0; i < dynamic_input_size; ++i) { | |||||
| device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); | |||||
| formats->emplace_back(kernel_build_info_format); | |||||
| reshape_types->emplace_back(reshape_type); | |||||
| } | |||||
| dynamic_input_index++; | |||||
| real_io_tensor_index += dynamic_input_size; | |||||
| } else { | |||||
| if (ios_info.size() != 1) { | |||||
|           MS_LOG(EXCEPTION) << "If the output is dynamic, the op must have exactly one output."; | |||||
| } | |||||
| for (size_t i = 0; i < real_io_tensor_num; ++i) { | |||||
| device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); | |||||
| formats->emplace_back(kernel_build_info_format); | |||||
| reshape_types->emplace_back(reshape_type); | |||||
| } | |||||
| real_io_tensor_index += real_io_tensor_num; | |||||
| } | |||||
| } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) { | |||||
| // requre or optional io | |||||
| device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); | |||||
| formats->emplace_back(kernel_build_info_format); | |||||
| reshape_types->emplace_back(reshape_type); | |||||
| real_io_tensor_index++; | |||||
| } else { | |||||
|       MS_LOG(EXCEPTION) << "op info's param type does not match: " << io_param_type; | |||||
| } | |||||
| } | |||||
| if (io_info_index != io_info_num) { | |||||
| MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num | |||||
|                  << "), this node may have optional inputs/outputs."; | |||||
| } | |||||
| if (real_io_tensor_index != real_io_tensor_num) { | |||||
|     std::string io_type = is_input ? "inputs" : "outputs"; | |||||
|     MS_LOG(INFO) << node_name_ << "'s " << io_type << " op io info num: " << io_info_num | |||||
|                  << ", real io tensor num: " << real_io_tensor_num << ", real_io_tensor_index(" | |||||
|                  << real_io_tensor_index << ") != real_io_tensor_num(" << real_io_tensor_num << ")"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) { | |||||
| MS_EXCEPTION_IF_NULL(reshape_type_vec); | |||||
| for (const auto &c : reshape_type_str) { | |||||
| switch (c) { | |||||
| case 'N': | |||||
| reshape_type_vec->push_back(kernel::N); | |||||
| break; | |||||
| case 'C': | |||||
| reshape_type_vec->push_back(kernel::C); | |||||
| break; | |||||
| case 'H': | |||||
| reshape_type_vec->push_back(kernel::H); | |||||
| break; | |||||
| case 'W': | |||||
| reshape_type_vec->push_back(kernel::W); | |||||
| break; | |||||
| default: | |||||
|         MS_LOG(EXCEPTION) << "Unknown axis " << c << " in reshape type."; | |||||
| } | |||||
| } | |||||
| } | |||||
| void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, | |||||
| const std::vector<std::vector<std::string>> &support_format_item, size_t index, | |||||
| mindspore::kernel::OpIOInfo *op_io_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_io_info_new); | |||||
| op_io_info_new->set_index(op_io_info.index()); | |||||
| op_io_info_new->set_name(op_io_info.name()); | |||||
| op_io_info_new->set_param_type(op_io_info.param_type()); | |||||
| op_io_info_new->set_need_compile(op_io_info.need_compile()); | |||||
| op_io_info_new->set_reshape_type(op_io_info.reshape_type()); | |||||
| op_io_info_new->set_shape(op_io_info.shape()); | |||||
| // dtype | |||||
| std::vector<std::string> dtype_new; | |||||
| auto dtype = op_io_info.dtypes(); | |||||
| for (size_t i = 0; i < support_format_item.size(); ++i) { | |||||
| dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); | |||||
| } | |||||
| op_io_info_new->set_dtypes(dtype_new); | |||||
| // format | |||||
| std::vector<std::string> format_new; | |||||
| for (const auto &formats : support_format_item) { | |||||
| auto format = formats.at(index); | |||||
| for (size_t j = 0; j < dtype.size(); ++j) { | |||||
| format_new.emplace_back(format); | |||||
| } | |||||
| } | |||||
| op_io_info_new->set_formats(format_new); | |||||
| } | |||||
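| // Split a comma separated op_select_format item into a vector, mapping TBE format names to MindSpore format names. | |||||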
| std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { | |||||
| const std::map<std::string, std::string> kDynamicFormatMap = { | |||||
| {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; | |||||
| if (op_select_json_item.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Op select ret item is null."; | |||||
| } | |||||
| const char space = ' '; | |||||
| const char sep = ','; | |||||
| std::string op_select_tmp = op_select_json_item + ","; | |||||
| std::vector<std::string> ret; | |||||
| auto begin = op_select_tmp.find_first_not_of(space, 0); | |||||
| auto sep_pos = op_select_tmp.find(sep); | |||||
| if (begin >= sep_pos) { | |||||
|     MS_LOG(EXCEPTION) << "Op select ret json is invalid."; | |||||
| } | |||||
| while (sep_pos != std::string::npos) { | |||||
| auto obj = op_select_tmp.substr(begin, sep_pos - begin); | |||||
| if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { | |||||
| obj = kDynamicFormatMap.at(obj); | |||||
| } | |||||
| ret.emplace_back(obj); | |||||
| begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); | |||||
| sep_pos = op_select_tmp.find(sep, begin); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| std::string TbeKernelSelect::OpSelectFormat() { | |||||
| nlohmann::json kernel_json; | |||||
| std::string res_json_str; | |||||
| TbeKernelJsonCreator creator(OP_SELECT_FORMAT); | |||||
| bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; | |||||
| } | |||||
| res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); | |||||
| if (res_json_str.empty()) { | |||||
| MS_LOG(EXCEPTION) << "op select format error."; | |||||
| } | |||||
|   MS_LOG(INFO) << "Dynamic select format response result: " << res_json_str; | |||||
| return res_json_str; | |||||
| } | |||||
| void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, | |||||
| mindspore::kernel::OpInfo *op_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_info_new); | |||||
| if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || | |||||
| op_info.outputs_ptr().size() != support_format.output_format[0].size()) { | |||||
| MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() | |||||
| << ", input support size: " << support_format.input_format[0].size() | |||||
| << ", op info output size: " << op_info.outputs_ptr().size() | |||||
| << ", output support size: " << support_format.output_format[0].size(); | |||||
| } | |||||
| *op_info_new = op_info; | |||||
| op_info_new->ClearInputs(); | |||||
| op_info_new->ClearOutputs(); | |||||
| for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { | |||||
| auto input = op_info.inputs_ptr().at(i); | |||||
| auto input_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); | |||||
| op_info_new->add_inputs_ptr(input_new); | |||||
| } | |||||
| for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { | |||||
| auto output = op_info.outputs_ptr().at(j); | |||||
| auto output_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); | |||||
| op_info_new->add_outputs_ptr(output_new); | |||||
| } | |||||
| } | |||||
| struct SelectOpIOInfo { | |||||
| std::string name; | |||||
| std::vector<std::string> dtypes; | |||||
| std::vector<std::string> formats; | |||||
| }; | |||||
| void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, | |||||
| mindspore::kernel::OpInfo *op_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_info_new); | |||||
| auto op_seclect_json = OpSelectFormat(); | |||||
| if (!op_seclect_json.empty()) { | |||||
| nlohmann::json json_obj = nlohmann::json::parse(op_seclect_json); | |||||
| if (!json_obj.is_object()) { | |||||
| MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json; | |||||
| } | |||||
| std::vector<SelectOpIOInfo> inputs; | |||||
| std::vector<SelectOpIOInfo> outputs; | |||||
| for (const auto &item : json_obj.items()) { | |||||
| const std::string &item_name = item.key(); | |||||
| bool is_input = (item_name.find(kPrefixInput) != std::string::npos); | |||||
| bool is_output = (item_name.find(kPrefixOutput) != std::string::npos); | |||||
| if (!is_input && !is_output) { | |||||
|         MS_LOG(EXCEPTION) << "Op select ret json is invalid."; | |||||
| } | |||||
| if (is_input) { | |||||
| SelectOpIOInfo select_input; | |||||
| select_input.name = item.value().at(kName); | |||||
| std::string input_dtype_item = item.value().at(kDtype); | |||||
| select_input.dtypes = SplitStrToVec(input_dtype_item); | |||||
| std::string input_format_item = item.value().at(kFormat); | |||||
| select_input.formats = SplitStrToVec(input_format_item); | |||||
| inputs.emplace_back(select_input); | |||||
| } else if (is_output) { | |||||
| SelectOpIOInfo select_output; | |||||
| select_output.name = item.value().at(kName); | |||||
| std::string input_dtype_item = item.value().at(kDtype); | |||||
| select_output.dtypes = SplitStrToVec(input_dtype_item); | |||||
| std::string input_format_item = item.value().at(kFormat); | |||||
| select_output.formats = SplitStrToVec(input_format_item); | |||||
| outputs.emplace_back(select_output); | |||||
| } | |||||
| } | |||||
| if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) { | |||||
| MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register."; | |||||
| } | |||||
| *op_info_new = op_info; | |||||
| op_info_new->ClearInputs(); | |||||
| op_info_new->ClearOutputs(); | |||||
| for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { | |||||
| auto input_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get()); | |||||
| op_info_new->add_inputs_ptr(input_new); | |||||
| } | |||||
| for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) { | |||||
| auto output_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); | |||||
| op_info_new->add_outputs_ptr(output_new); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, | |||||
| const std::vector<std::string> &support_dtype, | |||||
| const std::vector<std::string> &support_format, | |||||
| mindspore::kernel::OpIOInfo *op_io_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_io_info_new); | |||||
| op_io_info_new->set_index(op_io_info.index()); | |||||
| op_io_info_new->set_name(op_io_info.name()); | |||||
| op_io_info_new->set_param_type(op_io_info.param_type()); | |||||
| op_io_info_new->set_need_compile(op_io_info.need_compile()); | |||||
| op_io_info_new->set_reshape_type(op_io_info.reshape_type()); | |||||
| op_io_info_new->set_shape(op_io_info.shape()); | |||||
| // dtype && format | |||||
| op_io_info_new->set_dtypes(support_dtype); | |||||
| op_io_info_new->set_formats(support_format); | |||||
| } | |||||
| void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { | |||||
| if (support_format.input_format.size() != support_format.output_format.size()) { | |||||
|     MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ") and output(" | |||||
|                       << support_format.output_format.size() << ") size does not match."; | |||||
| } | |||||
| for (size_t i = 0; i < support_format.input_format.size(); ++i) { | |||||
| auto input_items = support_format.input_format.at(i); | |||||
| auto output_items = support_format.output_format.at(i); | |||||
| std::string print_str = "["; | |||||
| for (const auto &input : input_items) { | |||||
| print_str.append(input); | |||||
| print_str.append(", "); | |||||
| } | |||||
| print_str.append("] -->"); | |||||
| for (const auto &output : output_items) { | |||||
| print_str.append(output); | |||||
| print_str.append(", "); | |||||
| } | |||||
| MS_LOG(INFO) << "Support format: " << print_str; | |||||
| } | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -1,492 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ascend_backend_optimization.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <set> | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/ascend/ir_fission/bn_split.h" | |||||
| #include "pre_activate/ascend/ir_fission/bn_grad_split.h" | |||||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" | |||||
| #include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" | |||||
| #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" | |||||
| #include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" | |||||
| #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" | |||||
| #include "pre_activate/pass/communication_op_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h" | |||||
| #include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h" | |||||
| #include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" | |||||
| #include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" | |||||
| #include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h" | |||||
| #include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" | |||||
| #include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h" | |||||
| #include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fission/transdata_split.h" | |||||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | |||||
| #include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | |||||
| #include "pre_activate/ascend/ir_fusion/derelu_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" | |||||
| #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | |||||
| #include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||||
| #include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" | |||||
| #include "pre_activate/ascend/format_type/insert_trans_op.h" | |||||
| #include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" | |||||
| #include "pre_activate/pass/getitem_tuple.h" | |||||
| #include "pre_activate/pass/optimize_dependence.h" | |||||
| #include "pre_activate/pass/erase_visit_attr.h" | |||||
| #include "pre_activate/ascend/format_type/insert_cast.h" | |||||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | |||||
| #include "pre_activate/pass/eliminate_redundant_op.h" | |||||
| #include "pre_activate/pass/common_subexpression_elimination.h" | |||||
| #include "pre_activate/pass/fuse_graph_kernel.h" | |||||
| #include "pre_activate/pass/fuse_basic.h" | |||||
| #include "pre_activate/pass/add_atomic_clean.h" | |||||
| #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | |||||
| #include "pre_activate/ascend/format_type/check_consistency.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" | |||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||||
| #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" | |||||
| #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | |||||
| #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | |||||
| #include "pre_activate/ascend/ir_fission/addn_fission.h" | |||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | |||||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | |||||
| #include "pre_activate/ascend/ir_fission/split_fission.h" | |||||
| #include "pre_activate/ascend/format_type/modify_ops_attrs.h" | |||||
| #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" | |||||
| #include "utils/context/ms_context.h" | |||||
| #include "utils/config_manager.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "debug/anf_ir_utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||||
| MS_EXCEPTION_IF_NULL(ir_fusion_pm); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV2>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV3>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond1>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond2>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond3>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond4Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond1>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond2>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond3>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond4>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond5>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| } | |||||
| } // namespace | |||||
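| // Data layout optimization for single-op (PyNative) execution: rectify DoMask kernel info, insert trans ops for run-op nodes, and eliminate redundant ops. | |||||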
| void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm"); | |||||
| data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>()); | |||||
| data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>()); | |||||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||||
| data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | |||||
| data_layout_pm->AddPass(std::make_shared<OptimizeDependence>()); | |||||
| data_layout_pm->AddPass(std::make_shared<TransDataSplit>()); | |||||
| data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| optimizer->AddPassManager(data_layout_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| } | |||||
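| // Common processing for graph kernels: normalize op attributes and remove no-use Reshape ops. | |||||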
| void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| MS_EXCEPTION_IF_NULL(optimizer); | |||||
| auto common_process = std::make_shared<PassManager>("graph_kernel_common_process"); | |||||
| MS_EXCEPTION_IF_NULL(common_process); | |||||
| common_process->AddPass(std::make_shared<ModifyOpAttrs>()); | |||||
| common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>()); | |||||
| optimizer->AddPassManager(common_process); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| } | |||||
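| // Graph-mode data layout optimization: insert trans ops so every kernel gets its expected format, then remove redundant reshape pairs and other redundant ops. | |||||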
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto data_layout_pm = std::make_shared<PassManager>("transop_pm"); | |||||
| data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>()); | |||||
| data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); | |||||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||||
| data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>()); | |||||
| data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | |||||
| data_layout_pm->AddPass(std::make_shared<OptimizeDependence>()); | |||||
| data_layout_pm->AddPass(std::make_shared<TransDataSplit>()); | |||||
| data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| optimizer->AddPassManager(data_layout_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| } | |||||
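| // Mixed precision optimization: insert Cast ops, merge casts into adjacent ops where possible, and fix up ref outputs via DealRefTransAndCast. | |||||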
| void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm"); | |||||
| mixed_precision_pm->AddPass(std::make_shared<InsertCast>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | |||||
| optimizer->AddPassManager(mixed_precision_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| } | |||||
| void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" + | |||||
| std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| DumpIRProto(kernel_graph, "before_hwopt_" + std::to_string(kernel_graph->graph_id())); | |||||
| } | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||||
| if (context_ptr->ir_fusion_flag()) { | |||||
| AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); | |||||
| } | |||||
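| // Under task sink + loop sink with more than one iteration, additionally insert memcpy_async for GetNext outputs. | |||||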
| if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| } | |||||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>()); | |||||
| optimizer->AddPassManager(ir_fusion_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| } | |||||
| void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!context_ptr->ir_fusion_flag()) { | |||||
| MS_LOG(INFO) << "IRFusion is not enable, skip"; | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | |||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); | |||||
| optimizer->AddPassManager(ir_fusion_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| } | |||||
| void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| // data layout optimization | |||||
| AscendDataLayout(kernel_graph); | |||||
| // mixed precision optimization | |||||
| AscendMixPrecision(kernel_graph); | |||||
| // other optimization | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto other_pm = std::make_shared<PassManager>("other_pm"); | |||||
| other_pm->AddPass(std::make_shared<AllReduceFusion>()); | |||||
| other_pm->AddPass(std::make_shared<AllGatherFusion>()); | |||||
| other_pm->AddPass(std::make_shared<ReduceScatterFusion>()); | |||||
| other_pm->AddPass(std::make_shared<BroadcastFusion>()); | |||||
| other_pm->AddPass(std::make_shared<ParameterTransOpFusion>()); | |||||
| other_pm->AddPass(std::make_shared<RefreshParameterFormat>()); | |||||
| optimizer->AddPassManager(other_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| // buffer fusion | |||||
| AscendBackendUBFusionOptimization(kernel_graph); | |||||
| // other2 optimization | |||||
| auto optimizer2 = std::make_shared<GraphOptimizer>(); | |||||
| auto other2_pm = std::make_shared<PassManager>("other2_pm"); | |||||
| other2_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||||
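| // Under task sink + loop sink with more than one iteration, eliminate redundant memcpy around GetNext. | |||||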
| if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { | |||||
| other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>()); | |||||
| } | |||||
| other2_pm->AddPass(std::make_shared<CheckConsistency>()); | |||||
| optimizer2->AddPassManager(other2_pm); | |||||
| (void)optimizer2->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| DumpIRProto(kernel_graph, "after_hwopt"); | |||||
| kernel_graph->DumpFuncGraph("hwopt_d_end"); | |||||
| } | |||||
| } | |||||
| void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||||
| bool is_before_kernel_select) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| // Fuse graph kernels with basic ops | |||||
| FuseGraphKernel(kernel_graph, is_before_kernel_select); | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| } | |||||
| void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||||
| bool is_before_kernel_select) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| // Fuse basic ops with basic ops | |||||
| FuseBasic(kernel_graph, is_before_kernel_select); | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| } | |||||
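| // Insert atomic-clean ops for graph kernels; only applies when graph kernel is enabled. | |||||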
| void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + | |||||
| std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| AddAtomicClean(kernel_graph); | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| } | |||||
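| // UB fusion: allocate fusion ids and fuse matched TBE patterns (conv, bn, matmul, eltwise, etc.) into single kernels via UbPatternFusion. | |||||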
| void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!context_ptr->ir_fusion_flag()) { | |||||
| MS_LOG(INFO) << "UBFusion is not enable, skip"; | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| auto fusion_id_allocator = std::make_shared<FusionIdAllocator>(); | |||||
| MS_EXCEPTION_IF_NULL(fusion_id_allocator); | |||||
| fusion_id_allocator->Init(); | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm"); | |||||
| ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<MultiOutputFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); | |||||
| optimizer->AddPassManager(ub_fusion_pm); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,226 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "kernel/oplib/oplib.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "session/kernel_graph.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
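| // Walk backwards through ref ops and format/type-only ops (Cast, Transpose, Reshape, TransData) to find the node that originally produced the referenced output. | |||||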
| session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||||
| session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); | |||||
| AnfNodePtr cur_node = kernel_with_index.first; | |||||
| size_t cur_out_index = kernel_with_index.second; | |||||
| MS_EXCEPTION_IF_NULL(cur_node); | |||||
| if (cur_node->isa<CNode>()) { | |||||
| auto cnode = cur_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(cnode); | |||||
| auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); | |||||
| // deal with ref op | |||||
| if (op_info != nullptr && op_info->is_ref()) { | |||||
| auto ref_infos = op_info->ref_infos(); | |||||
| if (ref_infos.count(cur_out_index) != 0) { | |||||
| auto in_index = ref_infos.at(cur_out_index); | |||||
| if (in_index > cnode->inputs().size()) { | |||||
| MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() | |||||
| << ", ref info is " << cur_out_index; | |||||
| } | |||||
| AnfNodePtr next_node = cnode->input(in_index + 1); | |||||
| return FindRefOriginNode(next_node); | |||||
| } | |||||
| } | |||||
| // deal with special (trans, cast, reshape) ops | |||||
| if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || | |||||
| op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { | |||||
| AnfNodePtr next_node = cnode->input(1); | |||||
| return FindRefOriginNode(next_node); | |||||
| } | |||||
| } | |||||
| return kernel_with_index; | |||||
| } | |||||
| void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | |||||
| const AnfNodePtr &final_node, size_t final_index, | |||||
| const session::KernelWithIndex &origin_pair) { | |||||
| // record the ref_pair | |||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| // if the final node is the get item, no trans or cast op was added, so the final node is the cnode itself; | |||||
| // add the pair for the cnode instead, because the get item will be removed later | |||||
| auto final_ref = (final_node == get_item ? cnode : final_node); | |||||
| session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index); | |||||
| if (kernel_graph->IsInRefOutputMap(final_pair)) { | |||||
| MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is " | |||||
| << final_index; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is " | |||||
| << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr " | |||||
| << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index " | |||||
| << origin_pair.second << "}"; | |||||
| kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair); | |||||
| } | |||||
| // if get_item is nullptr, the additional node will link to the cnode; | |||||
| // otherwise the additional node will link to the get_item node (which itself links to the cnode) | |||||
| AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | |||||
| size_t input_index, const AnfNodePtr &get_item) { | |||||
| AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); | |||||
| size_t final_index = output_index; | |||||
| AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | |||||
| session::KernelWithIndex origin_pair; | |||||
| origin_pair = FindRefOriginNode(input_node); | |||||
| MS_EXCEPTION_IF_NULL(origin_pair.first); | |||||
| if (!origin_pair.first->isa<Parameter>()) { | |||||
| MS_LOG(WARNING) << "ref op origin node is not parameter"; | |||||
| } | |||||
| MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is " | |||||
| << origin_pair.first->DebugString() << ", index is " << origin_pair.second; | |||||
| auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second); | |||||
| auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second); | |||||
| auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); | |||||
| auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); | |||||
| auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index); | |||||
| // insert trans | |||||
| if (origin_format != cur_format && cur_shape.size() > 1) { | |||||
| auto kernel_select = std::make_shared<KernelSelect>(); | |||||
| final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | |||||
| RefreshKernelBuildInfo(cur_format, origin_format, final_node); | |||||
| final_index = 0; | |||||
| MS_EXCEPTION_IF_NULL(final_node); | |||||
| MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | |||||
| } | |||||
| // insert cast | |||||
| if (origin_type != cur_type) { | |||||
| final_node = | |||||
| AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, cur_shape, cur_type); | |||||
| MS_EXCEPTION_IF_NULL(final_node); | |||||
| final_node->set_scope(cnode->scope()); | |||||
| final_index = 0; | |||||
| MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); | |||||
| } | |||||
| // add ref pair | |||||
| AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); | |||||
| // insert depend | |||||
| if (origin_format != cur_format || origin_type != cur_type) { | |||||
| std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; | |||||
| final_node = func_graph->NewCNode(depend_nodes); | |||||
| MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString(); | |||||
| } | |||||
| return final_node; | |||||
| } | |||||
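| // For multi-output ref ops: wrap each output in a TupleGetItem, append trans/cast for ref outputs, and pack the results back into a MakeTuple. | |||||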
| AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | |||||
| auto ref_infos = op_info->ref_infos(); | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||||
| AbstractBasePtrList abstract_list; | |||||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||||
| AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); | |||||
| // deal with ref output | |||||
| if (ref_infos.count(output_index) != 0) { | |||||
| auto input_index = ref_infos.at(output_index); | |||||
| final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(final_node); | |||||
| abstract_list.push_back(final_node->abstract()); | |||||
| make_tuple_inputs.push_back(final_node); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| return make_tuple; | |||||
| } | |||||
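| // For single-output ref ops: append trans/cast directly to the cnode; only the first ref info is processed. | |||||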
| AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(op_info); | |||||
| auto ref_infos = op_info->ref_infos(); | |||||
| for (const auto &ref_info : ref_infos) { | |||||
| if (ref_info.second > cnode->inputs().size()) { | |||||
| MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is " | |||||
| << ref_info.second; | |||||
| } | |||||
| return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef DealRefTransAndCast::DefinePattern() const { | |||||
| VarPtr V = std::make_shared<CondVar>(UnVisited); | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| return VectorRef({V, Xs}); | |||||
| } | |||||
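| // Treat Broadcast as a ref op: record a ref pair from each output back to the corresponding input. | |||||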
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { | |||||
| auto input_size = AnfAlgo::GetInputTensorNum(cnode); | |||||
| for (size_t i = 0; i < input_size; ++i) { | |||||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i); | |||||
| auto input_node = input_node_with_index.first; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope(); | |||||
| AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index); | |||||
| } | |||||
| } | |||||
| } | |||||
| const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (node == nullptr || !node->isa<CNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!AnfAlgo::IsRealCNodeKernel(cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| DealBroadCastAsRef(graph, cnode); | |||||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||||
| auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); | |||||
| if (op_info == nullptr || !op_info->is_ref()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto type = cnode->Type(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| if (!type->isa<Tuple>()) { | |||||
| return DealRefSigleOutput(graph, cnode, op_info); | |||||
| } | |||||
| return DealRefForMultipleOutput(graph, cnode, op_info); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,64 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/pass/convert_const_input_to_attr.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <memory> | |||||
| #include "pre_activate/pass/const_input_to_attr_registry.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "utils/utils.h" | |||||
| #include "utils/context/ms_context.h" | |||||
| #include "operator/ops.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
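| // Convert constant inputs of registered ops into node attributes; for graph kernels, apply this to every real kernel node inside the sub graph. | |||||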
| const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> todos; | |||||
| if (AnfAlgo::IsGraphKernel(node)) { | |||||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| MS_EXCEPTION_IF_NULL(sub_graph); | |||||
| kernel::GetValidKernelNodes(sub_graph, &todos); | |||||
| } else { | |||||
| todos.push_back(node); | |||||
| } | |||||
| for (auto &t : todos) { | |||||
| CNodePtr cnode = t->cast<CNodePtr>(); | |||||
| ConstInputToAttrInfoRegister reg; | |||||
| if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) { | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() || | |||||
| AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) { | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); | |||||
| } | |||||
| return node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||