diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc deleted file mode 100644 index 99e792216f..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ /dev/null @@ -1,312 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/aicpu/aicpu_kernel_build.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "kernel/aicpu/aicpu_kernel_mod.h" -#include "kernel/akg/akg_kernel_build.h" -#include "proto/tensor.pb.h" -#include "proto/tensor_shape.pb.h" -#include "proto/attr.pb.h" -#include "proto/node_def.pb.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "kernel/aicpu/aicpu_util.h" -#include "session/kernel_graph.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -using FNodeAttrHandle = std::function &anf_node, mindspore::NodeDef *proto)>; - -bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input_num, - std::vector *input_size_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_size_list); - for (size_t i = 0; i < input_num; i++) { - std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); - if (AnfAlgo::GetInputDeviceDataType(anf_node, i) == kObjectTypeString) { - if (!anf_node->isa()) { - MS_LOG(EXCEPTION) << "anf_node is not CNode."; - } - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < (i + 1)) { - MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1; - return false; - } - auto input_node = cnode->inputs()[i + 1]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa()) { - auto value_ptr = GetValueNode(input_node); - auto value = GetValue(value_ptr); - input_size_list->push_back(value.size()); - } - } else { - auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); - MS_EXCEPTION_IF_NULL(type_ptr); - int64_t size_i = 1; - for (size_t j = 0; j < shape_i.size(); j++) { - size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); - } - size_t type_byte = GetTypeByte(type_ptr); - if (type_byte == 0) { - return false; - } - size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); - input_size_list->push_back(LongToSize(size_i)); - } - } - return true; -} - -bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr &kernel_mod_ptr) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - std::vector input_size_list; - std::vector output_size_list; - size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); - - if (!SetIOIputSize(anf_node, input_num, &input_size_list)) { - return false; - } - kernel_mod_ptr->SetInputSizeList(input_size_list); - - for (size_t i = 0; i < output_num; i++) { - std::vector shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); - TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); - MS_EXCEPTION_IF_NULL(type_ptr); - int64_t size_i = 1; - for (size_t j = 0; j < shape_i.size(); j++) { - size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); - } - size_t type_byte = GetTypeByte(type_ptr); - if (type_byte == 0) { - return false; - } - size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); - output_size_list.push_back(LongToSize(size_i)); - } - kernel_mod_ptr->SetOutputSizeList(output_size_list); - return true; -} - -void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, - ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { - MS_EXCEPTION_IF_NULL(node_attr); - MS_EXCEPTION_IF_NULL(value); - if (type == "int") { - auto attr_value = GetValue(value); - (*node_attr)[attr_name].set_i(attr_value); - } else if (type == "str") { - auto attr_value = GetValue(value); - (*node_attr)[attr_name].set_s(attr_value); - } else if (type == "bool") { - auto attr_value = GetValue(value); - (*node_attr)[attr_name].set_b(attr_value); - } else if (type == "float") { - auto attr_value = GetValue(value); - (*node_attr)[attr_name].set_f(attr_value); - } else if (type == "listInt") { - std::vector attr_value; - auto value_type = value->type(); - MS_EXCEPTION_IF_NULL(value_type); - auto value_type_str = value_type->ToString(); - if (value_type_str == "Int32") { - int data = GetValue(value); - attr_value.push_back(data); - } else { - attr_value = GetValue>(value); - } - mindspore::AttrValue input_shape_attr; - mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); - MS_EXCEPTION_IF_NULL(input_shape_attr_list); - for (const auto shape : attr_value) { - input_shape_attr_list->add_i(shape); - } - (*node_attr)[attr_name] = input_shape_attr; - } else { - MS_LOG(EXCEPTION) << "type: " << type << "not support"; - } -} - -void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(proto); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - if (op_name == kPrint) { - return; - } - - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); - MS_EXCEPTION_IF_NULL(op_info_ptr); - auto attrs_ptr = op_info_ptr->attrs_ptr(); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); - for (const auto &attr_ptr : attrs_ptr) { - MS_EXCEPTION_IF_NULL(attr_ptr); - std::string attr_name = attr_ptr->name(); - auto value = primitive->GetAttr(attr_name); - if (value != nullptr) { - if (attr_name == kQueueName || attr_name == kSharedName) { - attr_name = kChannelName; - } else if (attr_name == kSeed0) { - attr_name = kSeed; - } else if (attr_name == kSeed1) { - attr_name = kSeed2; - } - std::string type = attr_ptr->type(); - ParseAttrValue(type, attr_name, value, node_attr); - } - } - MS_LOG(INFO) << "Set node attr end!"; -} - -void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(proto); - MS_EXCEPTION_IF_NULL(anf_node); - size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); - if (input_num == 0) { - MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; - return; - } - - for (size_t input_index = 0; input_index < input_num; input_index++) { - ::mindspore::Tensor *node_inputs = proto->add_inputs(); - MS_EXCEPTION_IF_NULL(node_inputs); - TypeId input_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index); - std::vector input_shape; - int32_t input_data_type; - if (input_type == kObjectTypeString) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_node = cnode->inputs()[input_index + 1]; - auto value_ptr = GetValueNode(input_node); - auto value = GetValue(value_ptr); - input_shape.push_back(1); - input_shape.push_back(value.size()); - input_data_type = AicpuOpUtil::MsTypeToProtoType(kTypeUnknown); - } else { - input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); - input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type); - } - - mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape(); - for (auto item : input_shape) { - mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); - dim->set_size((::google::protobuf::int64)item); - } - node_inputs->set_tensor_type((mindspore::DataType)input_data_type); - node_inputs->set_mem_device("HBM"); - } -} - -void SetNodeOutputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(proto); - MS_EXCEPTION_IF_NULL(anf_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); - if (output_num == 0) { - MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. "; - return; - } - - for (size_t output_index = 0; output_index < output_num; output_index++) { - ::mindspore::Tensor *node_outputs = proto->add_outputs(); - MS_EXCEPTION_IF_NULL(node_outputs); - std::vector output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); - mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape(); - MS_EXCEPTION_IF_NULL(tensorShape); - for (auto item : output_shape) { - mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); - MS_EXCEPTION_IF_NULL(dim); - dim->set_size((::google::protobuf::int64)item); - } - TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index); - int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type); - node_outputs->set_tensor_type((mindspore::DataType)output_data_type); - node_outputs->set_mem_device("HBM"); - } -} - -void SetNodedefProto(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(proto); - MS_LOG(INFO) << "SetNodedefProto entry"; - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - // set op name - proto->set_op(op_name); - // set inputs tensor - SetNodeInputs(anf_node, proto); - // set outputs tensor - SetNodeOutputs(anf_node, proto); - // set node attr - SetNodeAttr(anf_node, proto); - MS_LOG(INFO) << "SetNodedefProto end!"; -} - -bool CreateNodeDefBytes(const std::shared_ptr &anf_node, - const std::shared_ptr &kernel_mod_ptr) { - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "CreateNodeDefBytes entry"; - - mindspore::NodeDef proto; - SetNodedefProto(anf_node, &proto); - std::string nodeDefStr; - if (!proto.SerializeToString(&nodeDefStr)) { - MS_LOG(ERROR) << "Serialize nodeDef to string failed."; - return false; - } - kernel_mod_ptr->SetNodeDef(nodeDefStr); - MS_LOG(INFO) << "CreateNodeDefBytes end!"; - return true; -} - -KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - auto kernel_mod_ptr = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - kernel_mod_ptr->SetAnfNode(anf_node); - kernel_mod_ptr->SetNodeName(op_name); - if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { - MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!"; - } - if (!SetIOSize(anf_node, kernel_mod_ptr)) { - MS_LOG(EXCEPTION) << "Set input output size list failed."; - } - return kernel_mod_ptr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/kernel/aicpu/aicpu_util.h deleted file mode 100644 index f2092abbe2..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ - -#include -#include -#include -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -constexpr auto kInitDataSetQueue = "InitDataSetQueue"; -constexpr auto kInitData = "InitData"; -constexpr auto kGetNext = "GetNext"; -constexpr auto kPrint = "Print"; -constexpr auto kPack = "Pack"; - -constexpr auto kOutputTypes = "output_types"; -constexpr auto kOutputShapes = "output_shapes"; -constexpr auto kChannelName = "channel_name"; -constexpr auto kSharedName = "shared_name"; -constexpr auto kShapes = "shapes"; -constexpr auto kTypes = "types"; -constexpr auto kQueueName = "queue_name"; -constexpr auto kSeed = "seed"; -constexpr auto kSeed0 = "Seed0"; -constexpr auto kSeed1 = "Seed1"; -constexpr auto kSeed2 = "seed2"; -constexpr auto kTopK = "TopK"; -constexpr auto kTopKV2 = "TopKV2"; - -struct AicpuParamHead { - uint32_t length; // Total length: include cunstom message - uint32_t ioAddrNum; // Input and output address number - uint32_t extInfoLength; // extInfo struct Length - uint64_t extInfoAddr; // extInfo address -} __attribute__((packed)); - -class AicpuOpUtil { - public: - static int MsTypeToProtoType(TypeId ms_type); - - private: - // kernel id - static uint64_t KernelId_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc deleted file mode 100644 index 9951321f5e..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ /dev/null @@ -1,622 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "nlohmann/json.hpp" -#include "utils/context/ms_context.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "pre_activate/common/helper.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "parallel/ops_info/ops_utils.h" -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" - -namespace mindspore { -namespace kernel { -constexpr auto kName = "name"; -constexpr auto kDtype = "dtype"; -constexpr auto kFormat = "format"; -constexpr auto kPrefixInput = "input"; -constexpr auto kPrefixOutput = "output"; -constexpr char kParamTypeDynamic[] = "dynamic"; -constexpr char kParamTypeRequre[] = "required"; -constexpr char kParamTypeOptional[] = "optional"; -void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { - auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list); - tbe_selecter.TbeMetadataInfoEx(); -} - -TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list) - : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {} - -void TbeKernelSelect::TbeMetadataInfoEx() { - MS_EXCEPTION_IF_NULL(cnode_ptr_); - MS_EXCEPTION_IF_NULL(kernel_info_list_); - node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); - auto op_info_ptr = OpLib::FindOp(node_name_, kTBE); - if (!op_info_ptr) { - MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; - return; - } - MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_ - << ", node name: " << cnode_ptr_->fullname_with_scope(); - OpPattern pattern = op_info_ptr->op_pattern(); - if (pattern == kCommonPattern) { - GetCommonPatternKernelInfo(*op_info_ptr); - } else if (pattern == kDynamicFormatPattern) { - GetDynamicFormatPatternKernelInfo(*op_info_ptr); - } else if (pattern == kFormatAgnosticPattern) { - GetAgnosticPatternKernelInfo(*op_info_ptr); - } else if (pattern == kBroadcastPattern) { - GetBroadcastPatternKernelInfo(*op_info_ptr); - } else if (pattern == kReducePattern) { - GetReducePatternKernelInfo(*op_info_ptr); - } else { - MS_LOG(INFO) << "Warning: op pattern is invailed."; - } - // check support - FilterInVaildKernelInfo(); - MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; -} - -void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - // get dynamic inputs - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); - MS_EXCEPTION_IF_NULL(primitive); - std::vector dyn_input_sizes; - if (primitive->HasAttr(kAttrDynInputSizes)) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - // get real input/output num - size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); - const auto inputs_info = op_info.inputs_ptr(); - size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); - const auto outputs_info = op_info.outputs_ptr(); - if (inputs_info.empty() && outputs_info.empty()) { - MS_LOG(EXCEPTION) << "op info input & output is null, please check."; - } - // create kernel build info from opinfo - size_t kernel_build_info_num = - inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size(); - for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) { - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - SetTbeBuildCommonInfo(op_info, &builder); - std::vector inputs_format; - std::vector inputs_device_type; - std::vector> inputs_reshape_type; - // input - if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, - &inputs_format, &inputs_device_type, &inputs_reshape_type)) { - break; - } - builder.SetInputsDeviceType(inputs_device_type); - builder.SetInputsFormat(inputs_format); - builder.SetInputReshapeType(inputs_reshape_type); - // output - std::vector outputs_format; - std::vector outputs_device_type; - std::vector> outputs_reshape_type; - if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, - &outputs_format, &outputs_device_type, &outputs_reshape_type)) { - break; - } - builder.SetOutputsDeviceType(outputs_device_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputReshapeType(outputs_reshape_type); - kernel_info_list_->emplace_back(builder.Build()); - } - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - // - OpInfo op_info_new; - CreateNewOpInfo(op_info, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - if (op_info.inputs_ptr().size() != 1) { - MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; - } - auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); - if (kOpFormatList.find(format) == kOpFormatList.end()) { - MS_LOG(INFO) << "Got the unknown format " << format; - format = kOpFormat_DEFAULT; - } - SupportFormat support_format; - SupportFormatItem input_item; - SupportFormatItem output_item; - input_item.assign(op_info.inputs_ptr().size(), format); - output_item.assign(op_info.outputs_ptr().size(), format); - support_format.input_format.emplace_back(input_item); - support_format.output_format.emplace_back(output_item); - PrintSupportedFormat(support_format); - OpInfo op_info_new; - CreateNewOpInfo(op_info, support_format, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); - SupportFormat support_format; - broadcast_selecter.GetShapeInfo(&support_format); - if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; - } - if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; - } - if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; - } - if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; - } - PrintSupportedFormat(support_format); - OpInfo op_info_new; - CreateNewOpInfo(op_info, support_format, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_); - SupportFormat support_format; - reduce_selecter.GetShapeInfo(&support_format); - if (!reduce_selecter.IsReduceSupport5HD(&support_format)) { - MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support 5HD."; - } - if (reduce_selecter.IsReduceSupportFracZ(&support_format)) { - MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracZ."; - } - if (reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) { - MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support C1HWNCoC0."; - } - if (reduce_selecter.IsReduceSupportFracNZ(&support_format)) { - MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracNZ."; - } - PrintSupportedFormat(support_format); - OpInfo op_info_new; - CreateNewOpInfo(op_info, support_format, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::FilterInVaildKernelInfo() { - if (kernel_info_list_->empty()) { - MS_LOG(INFO) << "Warning: get kernel build info failed."; - return; - } - auto kernel_build_info_iter = kernel_info_list_->begin(); - while (kernel_build_info_iter != kernel_info_list_->end()) { - if (!FilterInVaildShape(kernel_build_info_iter)) { - MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString(); - kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); - continue; - } - if (!TbeCheckSupported(kernel_build_info_iter)) { - MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString(); - kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); - continue; - } - kernel_build_info_iter++; - } -} - -bool TbeKernelSelect::FilterInVaildShape( - const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { - MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); - auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); - for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); - auto format = kernel_build_info_inputs_format.at(i); - if (!IsShapeMatchFormat(shape, format)) { - MS_LOG(INFO) << "The " << i << "th input check failed."; - return false; - } - } - auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); - for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) { - auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j); - auto format = kernel_build_info_outputs_format.at(j); - if (!IsShapeMatchFormat(shape, format)) { - MS_LOG(INFO) << "The " << j << "th input check failed."; - return false; - } - } - return true; -} - -bool TbeKernelSelect::IsShapeMatchFormat(const std::vector &shape, const std::string &format) { - if (format == kOpFormat_DEFAULT) { - return true; - } - static std::set kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; - // if format is default, it remarkes support all format - if (kOpFormatList.find(format) == kOpFormatList.end()) { - MS_LOG(EXCEPTION) << "Got the unknown format " << format; - } - // server not support format with C04 suffix - if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) != - kServerNotSupportFormat.end()) { - MS_LOG(INFO) << "Warning: Server not support format with C04 suffix."; - return false; - } - // not support format: - // 1 NDHWC with shape size != 5 - // 2 FRAC_NZ with shape size < 2 - // 3 !NDHWC with shape size > 4 - if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || - (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || - (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { - MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); - return false; - } - return true; -} - -bool TbeKernelSelect::TbeCheckSupported( - const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { - MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); - static const std::set kCheckSupportedOpType = {parallel::MATMUL, - parallel::BATCHMATMUL, - parallel::TOPK, - parallel::IN_TOPK, - parallel::PACK, - parallel::UNSORTEF_SEGMENT_MIND, - parallel::UNSORTEF_SEGMENT_PRODD, - parallel::CAST}; - auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); - if (iter == kCheckSupportedOpType.end()) { - return true; - } - MS_LOG(INFO) << "Check support start."; - // replace kernel_info with current kernel info - auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); - AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); - nlohmann::json kernel_json; - TbeKernelJsonCreator creator(CHECK_SUPPORTED); - bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); - if (!ret) { - MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; - } - ret = TbePythonFuncs::CheckSupported(kernel_json); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); - return ret; -} - -void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info, - mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) { - MS_EXCEPTION_IF_NULL(builder); - builder->SetProcessor(AICORE); - std::string fusion_type = op_info.fusion_type(); - if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { - builder->SetFusionType(tbe::GetFusionType(fusion_type)); - } - builder->SetOpPattern(op_info.op_pattern()); - builder->SetKernelType(TBE_KERNEL); -} - -bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, - const std::vector> &ios_info, - const std::vector &dyn_input_sizes, std::vector *formats, - std::vector *device_types, std::vector> *reshape_types) { - MS_EXCEPTION_IF_NULL(formats); - MS_EXCEPTION_IF_NULL(device_types); - MS_EXCEPTION_IF_NULL(reshape_types); - size_t dynamic_input_index = 0; - size_t real_io_tensor_index = 0; - size_t io_info_index = 0; - size_t io_info_num = ios_info.size(); - for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) { - std::shared_ptr io_info_item = ios_info[io_info_index]; - auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index); - std::string kernel_build_info_format; - if (!io_info_item->formats().empty()) { - kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index); - } - std::string io_param_type = io_info_item->param_type(); - std::vector reshape_type; - StringToAxisVector(io_info_item->reshape_type(), &reshape_type); - if (io_param_type == kParamTypeDynamic) { - // dynamic io - if (is_input) { - if (dynamic_input_index >= dyn_input_sizes.size()) { - MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index - << ", dyn_input_sizes size: " << dyn_input_sizes.size(); - } - int dynamic_input_size = dyn_input_sizes[dynamic_input_index]; - for (int i = 0; i < dynamic_input_size; ++i) { - device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); - formats->emplace_back(kernel_build_info_format); - reshape_types->emplace_back(reshape_type); - } - dynamic_input_index++; - real_io_tensor_index += dynamic_input_size; - } else { - if (ios_info.size() != 1) { - MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output."; - } - for (size_t i = 0; i < real_io_tensor_num; ++i) { - device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); - formats->emplace_back(kernel_build_info_format); - reshape_types->emplace_back(reshape_type); - } - real_io_tensor_index += real_io_tensor_num; - } - } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) { - // requre or optional io - device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); - formats->emplace_back(kernel_build_info_format); - reshape_types->emplace_back(reshape_type); - real_io_tensor_index++; - } else { - MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type; - } - } - - if (io_info_index != io_info_num) { - MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num - << "), this node may has optional input/output."; - } - if (real_io_tensor_index != real_io_tensor_num) { - std::string io_type = is_input ? "inputs " : "outputs"; - MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num - << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index - << ") != real_io_tensor_num(" << real_io_tensor_num << ")"; - return false; - } - return true; -} - -void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec) { - MS_EXCEPTION_IF_NULL(reshape_type_vec); - for (const auto &c : reshape_type_str) { - switch (c) { - case 'N': - reshape_type_vec->push_back(kernel::N); - break; - case 'C': - reshape_type_vec->push_back(kernel::C); - break; - case 'H': - reshape_type_vec->push_back(kernel::H); - break; - case 'W': - reshape_type_vec->push_back(kernel::W); - break; - default: - MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; - } - } -} - -void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, - const std::vector> &support_format_item, size_t index, - mindspore::kernel::OpIOInfo *op_io_info_new) { - MS_EXCEPTION_IF_NULL(op_io_info_new); - op_io_info_new->set_index(op_io_info.index()); - op_io_info_new->set_name(op_io_info.name()); - op_io_info_new->set_param_type(op_io_info.param_type()); - op_io_info_new->set_need_compile(op_io_info.need_compile()); - op_io_info_new->set_reshape_type(op_io_info.reshape_type()); - op_io_info_new->set_shape(op_io_info.shape()); - // dtype - std::vector dtype_new; - auto dtype = op_io_info.dtypes(); - for (size_t i = 0; i < support_format_item.size(); ++i) { - dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); - } - op_io_info_new->set_dtypes(dtype_new); - // format - std::vector format_new; - for (const auto &formats : support_format_item) { - auto format = formats.at(index); - for (size_t j = 0; j < dtype.size(); ++j) { - format_new.emplace_back(format); - } - } - op_io_info_new->set_formats(format_new); -} - -std::vector TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { - const std::map kDynamicFormatMap = { - {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; - if (op_select_json_item.empty()) { - MS_LOG(EXCEPTION) << "Op select ret item is null."; - } - const char space = ' '; - const char sep = ','; - std::string op_select_tmp = op_select_json_item + ","; - std::vector ret; - auto begin = op_select_tmp.find_first_not_of(space, 0); - auto sep_pos = op_select_tmp.find(sep); - if (begin >= sep_pos) { - MS_LOG(EXCEPTION) << "Select ret json is error."; - } - while (sep_pos != std::string::npos) { - auto obj = op_select_tmp.substr(begin, sep_pos - begin); - if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { - obj = kDynamicFormatMap.at(obj); - } - ret.emplace_back(obj); - begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); - sep_pos = op_select_tmp.find(sep, begin); - } - return ret; -} - -std::string TbeKernelSelect::OpSelectFormat() { - nlohmann::json kernel_json; - std::string res_json_str; - TbeKernelJsonCreator creator(OP_SELECT_FORMAT); - bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); - if (!ret) { - MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; - } - res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); - if (res_json_str.empty()) { - MS_LOG(EXCEPTION) << "op select format error."; - } - MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; - return res_json_str; -} - -void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, - mindspore::kernel::OpInfo *op_info_new) { - MS_EXCEPTION_IF_NULL(op_info_new); - if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || - op_info.outputs_ptr().size() != support_format.output_format[0].size()) { - MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() - << ", input support size: " << support_format.input_format[0].size() - << ", op info output size: " << op_info.outputs_ptr().size() - << ", output support size: " << support_format.output_format[0].size(); - } - *op_info_new = op_info; - op_info_new->ClearInputs(); - op_info_new->ClearOutputs(); - for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { - auto input = op_info.inputs_ptr().at(i); - auto input_new = std::make_shared(); - CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); - op_info_new->add_inputs_ptr(input_new); - } - for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { - auto output = op_info.outputs_ptr().at(j); - auto output_new = std::make_shared(); - CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); - op_info_new->add_outputs_ptr(output_new); - } -} - -struct SelectOpIOInfo { - std::string name; - std::vector dtypes; - std::vector formats; -}; - -void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, - mindspore::kernel::OpInfo *op_info_new) { - MS_EXCEPTION_IF_NULL(op_info_new); - auto op_seclect_json = OpSelectFormat(); - if (!op_seclect_json.empty()) { - nlohmann::json json_obj = nlohmann::json::parse(op_seclect_json); - if (!json_obj.is_object()) { - MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json; - } - std::vector inputs; - std::vector outputs; - for (const auto &item : json_obj.items()) { - const std::string &item_name = item.key(); - bool is_input = (item_name.find(kPrefixInput) != std::string::npos); - bool is_output = (item_name.find(kPrefixOutput) != std::string::npos); - if (!is_input && !is_output) { - MS_LOG(EXCEPTION) << "op select ret json is error."; - } - if (is_input) { - SelectOpIOInfo select_input; - select_input.name = item.value().at(kName); - std::string input_dtype_item = item.value().at(kDtype); - select_input.dtypes = SplitStrToVec(input_dtype_item); - std::string input_format_item = item.value().at(kFormat); - select_input.formats = SplitStrToVec(input_format_item); - inputs.emplace_back(select_input); - } else if (is_output) { - SelectOpIOInfo select_output; - select_output.name = item.value().at(kName); - std::string input_dtype_item = item.value().at(kDtype); - select_output.dtypes = SplitStrToVec(input_dtype_item); - std::string input_format_item = item.value().at(kFormat); - select_output.formats = SplitStrToVec(input_format_item); - outputs.emplace_back(select_output); - } - } - - if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) { - MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register."; - } - - *op_info_new = op_info; - op_info_new->ClearInputs(); - op_info_new->ClearOutputs(); - for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { - auto input_new = std::make_shared(); - CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get()); - op_info_new->add_inputs_ptr(input_new); - } - for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) { - auto output_new = std::make_shared(); - CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); - op_info_new->add_outputs_ptr(output_new); - } - } -} - -void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, - const std::vector &support_dtype, - const std::vector &support_format, - mindspore::kernel::OpIOInfo *op_io_info_new) { - MS_EXCEPTION_IF_NULL(op_io_info_new); - op_io_info_new->set_index(op_io_info.index()); - op_io_info_new->set_name(op_io_info.name()); - op_io_info_new->set_param_type(op_io_info.param_type()); - op_io_info_new->set_need_compile(op_io_info.need_compile()); - op_io_info_new->set_reshape_type(op_io_info.reshape_type()); - op_io_info_new->set_shape(op_io_info.shape()); - // dtype && format - op_io_info_new->set_dtypes(support_dtype); - op_io_info_new->set_formats(support_format); -} - -void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { - if (support_format.input_format.size() != support_format.output_format.size()) { - MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" - << support_format.output_format.size() << ") size not match."; - } - for (size_t i = 0; i < support_format.input_format.size(); ++i) { - auto input_items = support_format.input_format.at(i); - auto output_items = support_format.output_format.at(i); - std::string print_str = "["; - for (const auto &input : input_items) { - print_str.append(input); - print_str.append(", "); - } - print_str.append("] -->"); - for (const auto &output : output_items) { - print_str.append(output); - print_str.append(", "); - } - MS_LOG(INFO) << "Support format: " << print_str; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc deleted file mode 100644 index 48ce87629c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ /dev/null @@ -1,492 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ascend_backend_optimization.h" -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fission/bn_split.h" -#include "pre_activate/ascend/ir_fission/bn_grad_split.h" -#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" -#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" -#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" -#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" -#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" -#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" -#include "pre_activate/pass/communication_op_fusion.h" -#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" -#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" -#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h" -#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h" -#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" -#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h" -#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" -#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h" -#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h" -#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" -#include "pre_activate/ascend/ir_fission/transdata_split.h" -#include "pre_activate/ascend/ir_fission/topk_split.h" -#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h" -#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" -#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" -#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" -#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" -#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" -#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" -#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" -#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" -#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" -#include "pre_activate/pass/getitem_tuple.h" -#include "pre_activate/pass/optimize_dependence.h" -#include "pre_activate/pass/erase_visit_attr.h" -#include "pre_activate/ascend/format_type/insert_cast.h" -#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" -#include "pre_activate/pass/eliminate_redundant_op.h" -#include "pre_activate/pass/common_subexpression_elimination.h" -#include "pre_activate/pass/fuse_graph_kernel.h" -#include "pre_activate/pass/fuse_basic.h" -#include "pre_activate/pass/add_atomic_clean.h" -#include "pre_activate/ascend/format_type/merge_cast_to_op.h" -#include "pre_activate/ascend/format_type/check_consistency.h" -#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" -#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" -#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" -#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" -#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" -#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" -#include "pre_activate/ascend/ir_fission/addn_fission.h" -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" -#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" -#include "pre_activate/ascend/ir_fission/split_fission.h" -#include "pre_activate/ascend/format_type/modify_ops_attrs.h" -#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" -#include "utils/context/ms_context.h" -#include "utils/config_manager.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" - -namespace mindspore { -namespace opt { -namespace { -void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { - MS_EXCEPTION_IF_NULL(ir_fusion_pm); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); -} -} // namespace - -void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto data_layout_pm = std::make_shared("pynative_transop_pm"); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(data_layout_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - MS_EXCEPTION_IF_NULL(optimizer); - auto common_process = std::make_shared("graph_kernel_common_process"); - MS_EXCEPTION_IF_NULL(common_process); - common_process->AddPass(std::make_shared()); - common_process->AddPass(std::make_shared()); - optimizer->AddPassManager(common_process); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendDataLayout(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto data_layout_pm = std::make_shared("transop_pm"); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(data_layout_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendMixPrecision(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto mixed_precision_pm = std::make_shared("cast_pm"); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(mixed_precision_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" + - std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - DumpIRProto(kernel_graph, "before_hwopt_" + std::to_string(kernel_graph->graph_id())); - } - auto optimizer = std::make_shared(); - auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - if (context_ptr->ir_fusion_flag()) { - AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); - } - - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - } - ir_fusion_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(ir_fusion_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } -} - -void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->ir_fusion_flag()) { - MS_LOG(INFO) << "IRFusion is not enable, skip"; - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir"; - DumpIR(file_path, kernel_graph); - } - auto optimizer = std::make_shared(); - auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - - optimizer->AddPassManager(ir_fusion_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir"; - DumpIR(file_path, kernel_graph); - } -} - -void AscendBackendOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - // data layout optimization - AscendDataLayout(kernel_graph); - // mixed precision optimization - AscendMixPrecision(kernel_graph); - // other optimization - auto optimizer = std::make_shared(); - auto other_pm = std::make_shared("other_pm"); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(other_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - // buffer fusion - AscendBackendUBFusionOptimization(kernel_graph); - - // other2 optimization - auto optimizer2 = std::make_shared(); - auto other2_pm = std::make_shared("other2_pm"); - other2_pm->AddPass(std::make_shared()); - other2_pm->AddPass(std::make_shared()); - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { - other2_pm->AddPass(std::make_shared()); - } - other2_pm->AddPass(std::make_shared()); - optimizer2->AddPassManager(other2_pm); - (void)optimizer2->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph, true); - DumpIRProto(kernel_graph, "after_hwopt"); - kernel_graph->DumpFuncGraph("hwopt_d_end"); - } -} - -void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph); - } - - // Fuse graph kernels with basic ops - FuseGraphKernel(kernel_graph, is_before_kernel_select); - - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph, true); - } -} - -void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph, true); - } - - // Fuse basic ops with basic ops - FuseBasic(kernel_graph, is_before_kernel_select); - - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph, true); - } -} - -void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + - std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - - AddAtomicClean(kernel_graph); - - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph, true); - } -} - -void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->ir_fusion_flag()) { - MS_LOG(INFO) << "UBFusion is not enable, skip"; - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - auto fusion_id_allocator = std::make_shared(); - MS_EXCEPTION_IF_NULL(fusion_id_allocator); - fusion_id_allocator->Init(); - auto optimizer = std::make_shared(); - auto ub_fusion_pm = std::make_shared("ub_fusion_pm"); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(ub_fusion_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc deleted file mode 100644 index 3241684c62..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ /dev/null @@ -1,226 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" -#include -#include -#include -#include -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { - session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); - AnfNodePtr cur_node = kernel_with_index.first; - size_t cur_out_index = kernel_with_index.second; - MS_EXCEPTION_IF_NULL(cur_node); - if (cur_node->isa()) { - auto cnode = cur_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string op_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); - // deal ref op - if (op_info != nullptr && op_info->is_ref()) { - auto ref_infos = op_info->ref_infos(); - if (ref_infos.count(cur_out_index) != 0) { - auto in_index = ref_infos.at(cur_out_index); - if (in_index > cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() - << ", ref info is " << cur_out_index; - } - AnfNodePtr next_node = cnode->input(in_index + 1); - return FindRefOriginNode(next_node); - } - } - - // deal special (trans,cast,reshape) op - if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || - op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { - AnfNodePtr next_node = cnode->input(1); - return FindRefOriginNode(next_node); - } - } - - return kernel_with_index; -} - -void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, - const AnfNodePtr &final_node, size_t final_index, - const session::KernelWithIndex &origin_pair) { - // record the ref_pair - auto kernel_graph = func_graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - // if the final node is get item, means no trans or cast op is added, the final node is itself - // so add the pair for itself, because the get item will removed later - auto final_ref = (final_node == get_item ? cnode : final_node); - session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index); - if (kernel_graph->IsInRefOutputMap(final_pair)) { - MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is " - << final_index; - } - MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is " - << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr " - << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index " - << origin_pair.second << "}"; - kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair); -} - -// if get_item is nullptr, the additional node will link to the cnode -// else the additional node will link to the get_item node (the get_item node link to cnode) -AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, - size_t input_index, const AnfNodePtr &get_item) { - AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); - size_t final_index = output_index; - AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); - session::KernelWithIndex origin_pair; - origin_pair = FindRefOriginNode(input_node); - MS_EXCEPTION_IF_NULL(origin_pair.first); - if (!origin_pair.first->isa()) { - MS_LOG(WARNING) << "ref op origin node is not parameter"; - } - MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is " - << origin_pair.first->DebugString() << ", index is " << origin_pair.second; - auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second); - auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second); - auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); - auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); - auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index); - // insert trans - if (origin_format != cur_format && cur_shape.size() > 1) { - auto kernel_select = std::make_shared(); - final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(cur_format, origin_format, final_node); - final_index = 0; - MS_EXCEPTION_IF_NULL(final_node); - MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); - } - // insert cast - if (origin_type != cur_type) { - final_node = - AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, cur_shape, cur_type); - MS_EXCEPTION_IF_NULL(final_node); - final_node->set_scope(cnode->scope()); - final_index = 0; - MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); - } - // add ref pair - AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); - // insert depend - if (origin_format != cur_format || origin_type != cur_type) { - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; - final_node = func_graph->NewCNode(depend_nodes); - MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString(); - } - - return final_node; -} -AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - auto ref_infos = op_info->ref_infos(); - std::vector make_tuple_inputs; - AbstractBasePtrList abstract_list; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); - // deal with ref output - if (ref_infos.count(output_index) != 0) { - auto input_index = ref_infos.at(output_index); - final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); - } - MS_EXCEPTION_IF_NULL(final_node); - abstract_list.push_back(final_node->abstract()); - make_tuple_inputs.push_back(final_node); - } - MS_EXCEPTION_IF_NULL(func_graph); - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - make_tuple->set_abstract(std::make_shared(abstract_list)); - return make_tuple; -} - -AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(op_info); - auto ref_infos = op_info->ref_infos(); - for (const auto &ref_info : ref_infos) { - if (ref_info.second > cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is " - << ref_info.second; - } - return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr); - } - return nullptr; -} -} // namespace - -const BaseRef DealRefTransAndCast::DefinePattern() const { - VarPtr V = std::make_shared(UnVisited); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { - auto input_size = AnfAlgo::GetInputTensorNum(cnode); - for (size_t i = 0; i < input_size; ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i); - auto input_node = input_node_with_index.first; - MS_EXCEPTION_IF_NULL(input_node); - MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope(); - AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index); - } - } -} - -const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::IsRealCNodeKernel(cnode)) { - return nullptr; - } - - DealBroadCastAsRef(graph, cnode); - - auto op_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); - if (op_info == nullptr || !op_info->is_ref()) { - return nullptr; - } - if (op_info->is_ref()) { - auto type = cnode->Type(); - MS_EXCEPTION_IF_NULL(type); - if (!type->isa()) { - return DealRefSigleOutput(graph, cnode, op_info); - } else { - return DealRefForMultipleOutput(graph, cnode, op_info); - } - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc deleted file mode 100644 index 89834cbc65..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/convert_const_input_to_attr.h" - -#include -#include -#include -#include - -#include "pre_activate/pass/const_input_to_attr_registry.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { - return nullptr; - } - std::vector todos; - if (AnfAlgo::IsGraphKernel(node)) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - kernel::GetValidKernelNodes(sub_graph, &todos); - } else { - todos.push_back(node); - } - - for (auto &t : todos) { - CNodePtr cnode = t->cast(); - ConstInputToAttrInfoRegister reg; - if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) { - continue; - } - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() || - AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) { - if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) { - continue; - } - } - ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); - } - return node; -} -} // namespace opt -} // namespace mindspore