Browse Source

!279 fix folder location

Merge pull request !279 from changzherui/fix_fold_loc
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
d6fcf731ec
6 changed files with 0 additions and 1781 deletions
  1. +0
    -312
      mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc
  2. +0
    -65
      mindspore/ccsrc/kernel/aicpu/aicpu_util.h
  3. +0
    -622
      mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
  4. +0
    -492
      mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
  5. +0
    -226
      mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc
  6. +0
    -64
      mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc

+ 0
- 312
mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc View File

@@ -1,312 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/aicpu/aicpu_kernel_build.h"
#include <google/protobuf/text_format.h>
#include <fstream>
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <map>
#include "device/kernel_runtime.h"
#include "kernel/aicpu/aicpu_kernel_mod.h"
#include "kernel/akg/akg_kernel_build.h"
#include "proto/tensor.pb.h"
#include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h"
#include "proto/node_def.pb.h"
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
#include "kernel/aicpu/aicpu_util.h"
#include "session/kernel_graph.h"
#include "kernel/common_utils.h"

namespace mindspore {
namespace kernel {
using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;

// Computes the device-side byte size of each of the first |input_num| inputs of
// |anf_node| and appends them to |input_size_list|.
// Returns false when a cnode has fewer inputs than expected or an element type
// has an unknown (zero) byte width.
bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
std::vector<size_t> *input_size_list) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(input_size_list);
for (size_t i = 0; i < input_num; i++) {
std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
if (AnfAlgo::GetInputDeviceDataType(anf_node, i) == kObjectTypeString) {
// String input: its size is the length of the constant string value.
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < (i + 1)) {
MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1;
return false;
}
// inputs()[0] is the primitive, so tensor input i lives at index i + 1.
auto input_node = cnode->inputs()[i + 1];
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<ValueNode>()) {
auto value_ptr = GetValueNode(input_node);
auto value = GetValue<std::string>(value_ptr);
input_size_list->push_back(value.size());
}
// NOTE(review): a string input that is NOT a ValueNode appends nothing here,
// which shifts later entries of |input_size_list| — confirm callers guarantee
// string inputs are constant ValueNodes.
} else {
// Numeric input: size = product(device shape dims) * element byte width,
// using overflow-checked multiplication throughout.
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
// Unknown/unsized element type: report failure rather than record a zero size.
return false;
}
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
input_size_list->push_back(LongToSize(size_i));
}
}
return true;
}

// Computes input and output byte-size lists for |anf_node| and records them on
// |kernel_mod_ptr|. Returns false when any size cannot be determined
// (delegated input failure or a zero-width output element type).
bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
std::vector<size_t> input_size_list;
std::vector<size_t> output_size_list;
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);

// Inputs need special handling for constant string tensors; see SetIOIputSize.
if (!SetIOIputSize(anf_node, input_num, &input_size_list)) {
return false;
}
kernel_mod_ptr->SetInputSizeList(input_size_list);

// Each output size = product(device shape dims) * element byte width, with
// overflow-checked multiplication.
for (size_t i = 0; i < output_num; i++) {
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
// Unknown element width: fail instead of recording a zero-sized output.
return false;
}
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
output_size_list.push_back(LongToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);
return true;
}

// Translates one registered attribute |value| of kind |type| ("int", "str",
// "bool", "float" or "listInt") into the aicpu AttrValue proto stored under
// |attr_name| in |node_attr|. Throws for unsupported attribute kinds.
void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value,
                    ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) {
  MS_EXCEPTION_IF_NULL(node_attr);
  MS_EXCEPTION_IF_NULL(value);
  if (type == "int") {
    auto attr_value = GetValue<int>(value);
    (*node_attr)[attr_name].set_i(attr_value);
  } else if (type == "str") {
    auto attr_value = GetValue<std::string>(value);
    (*node_attr)[attr_name].set_s(attr_value);
  } else if (type == "bool") {
    auto attr_value = GetValue<bool>(value);
    (*node_attr)[attr_name].set_b(attr_value);
  } else if (type == "float") {
    auto attr_value = GetValue<float>(value);
    (*node_attr)[attr_name].set_f(attr_value);
  } else if (type == "listInt") {
    // A listInt attr may be stored as a single Int32 scalar; normalize it to a
    // one-element vector before filling the proto array.
    std::vector<int> attr_value;
    auto value_type = value->type();
    MS_EXCEPTION_IF_NULL(value_type);
    auto value_type_str = value_type->ToString();
    if (value_type_str == "Int32") {
      int data = GetValue<int>(value);
      attr_value.push_back(data);
    } else {
      attr_value = GetValue<std::vector<int>>(value);
    }
    mindspore::AttrValue input_shape_attr;
    mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array();
    MS_EXCEPTION_IF_NULL(input_shape_attr_list);
    for (const auto shape : attr_value) {
      input_shape_attr_list->add_i(shape);
    }
    (*node_attr)[attr_name] = input_shape_attr;
  } else {
    // Fixed message: original emitted "type: <type>not support" (missing space).
    MS_LOG(EXCEPTION) << "type: " << type << " not support";
  }
}

// Copies the attributes registered for |anf_node|'s op (in the AICPU op
// library) from the node's primitive into |proto|'s attrs map. A few framework
// attribute names are renamed to the names the aicpu kernels expect
// (queue/shared name -> channel_name, Seed0 -> seed, Seed1 -> seed2).
// Print nodes carry no attributes and are skipped.
void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(proto);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
// InitDataSetQueue is implemented by the aicpu InitData kernel.
if (op_name == kInitDataSetQueue) {
op_name = kInitData;
}
if (op_name == kPrint) {
return;
}

auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU);
MS_EXCEPTION_IF_NULL(op_info_ptr);
auto attrs_ptr = op_info_ptr->attrs_ptr();
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs();
for (const auto &attr_ptr : attrs_ptr) {
MS_EXCEPTION_IF_NULL(attr_ptr);
std::string attr_name = attr_ptr->name();
auto value = primitive->GetAttr(attr_name);
// Attributes the primitive does not carry are silently skipped.
if (value != nullptr) {
// Rename framework attr names to their aicpu counterparts.
if (attr_name == kQueueName || attr_name == kSharedName) {
attr_name = kChannelName;
} else if (attr_name == kSeed0) {
attr_name = kSeed;
} else if (attr_name == kSeed1) {
attr_name = kSeed2;
}
std::string type = attr_ptr->type();
ParseAttrValue(type, attr_name, value, node_attr);
}
}
MS_LOG(INFO) << "Set node attr end!";
}

// Fills |proto|'s input tensor descriptions from |anf_node|'s input device
// shapes and dtypes. Constant string inputs are encoded as a [1, strlen]
// tensor with an "unknown" proto dtype. Every input is tagged as residing in
// HBM (device memory).
void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(proto);
MS_EXCEPTION_IF_NULL(anf_node);
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
if (input_num == 0) {
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input.";
return;
}

for (size_t input_index = 0; input_index < input_num; input_index++) {
::mindspore::Tensor *node_inputs = proto->add_inputs();
MS_EXCEPTION_IF_NULL(node_inputs);
TypeId input_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index);
std::vector<size_t> input_shape;
int32_t input_data_type;
if (input_type == kObjectTypeString) {
// NOTE(review): presumably the string input is always a constant ValueNode
// here (as checked in SetIOIputSize) — confirm, since GetValueNode is called
// without an isa<ValueNode> guard.
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->inputs()[input_index + 1];
auto value_ptr = GetValueNode(input_node);
auto value = GetValue<std::string>(value_ptr);
// Shape convention for strings: [1, byte length].
input_shape.push_back(1);
input_shape.push_back(value.size());
input_data_type = AicpuOpUtil::MsTypeToProtoType(kTypeUnknown);
} else {
input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index);
input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type);
}

mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape();
for (auto item : input_shape) {
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
dim->set_size((::google::protobuf::int64)item);
}
node_inputs->set_tensor_type((mindspore::DataType)input_data_type);
node_inputs->set_mem_device("HBM");
}
}

void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(proto);
MS_EXCEPTION_IF_NULL(anf_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
if (output_num == 0) {
MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
return;
}

for (size_t output_index = 0; output_index < output_num; output_index++) {
::mindspore::Tensor *node_outputs = proto->add_outputs();
MS_EXCEPTION_IF_NULL(node_outputs);
std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape();
MS_EXCEPTION_IF_NULL(tensorShape);
for (auto item : output_shape) {
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
MS_EXCEPTION_IF_NULL(dim);
dim->set_size((::google::protobuf::int64)item);
}
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index);
int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type);
node_outputs->set_tensor_type((mindspore::DataType)output_data_type);
node_outputs->set_mem_device("HBM");
}
}

// Builds the complete NodeDef proto for |anf_node|: op name, input tensors,
// output tensors and attributes.
void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(proto);
  MS_LOG(INFO) << "SetNodedefProto entry";
  const std::string node_type = AnfAlgo::GetCNodeName(anf_node);
  // InitDataSetQueue is implemented by the aicpu InitData kernel.
  proto->set_op(node_type == kInitDataSetQueue ? std::string(kInitData) : node_type);
  // Describe input tensors, output tensors, then attributes.
  SetNodeInputs(anf_node, proto);
  SetNodeOutputs(anf_node, proto);
  SetNodeAttr(anf_node, proto);
  MS_LOG(INFO) << "SetNodedefProto end!";
}

bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "CreateNodeDefBytes entry";

mindspore::NodeDef proto;
SetNodedefProto(anf_node, &proto);
std::string nodeDefStr;
if (!proto.SerializeToString(&nodeDefStr)) {
MS_LOG(ERROR) << "Serialize nodeDef to string failed.";
return false;
}
kernel_mod_ptr->SetNodeDef(nodeDefStr);
MS_LOG(INFO) << "CreateNodeDefBytes end!";
return true;
}

// Builds the aicpu kernel mod for |anf_node|: creates the AicpuOpKernelMod,
// attaches the serialized NodeDef and the input/output size lists.
// Throws when NodeDef serialization or size computation fails.
KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  // InitDataSetQueue is implemented by the aicpu InitData kernel.
  if (op_name == kInitDataSetQueue) {
    op_name = kInitData;
  }
  auto kernel_mod_ptr = std::make_shared<AicpuOpKernelMod>();
  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  kernel_mod_ptr->SetAnfNode(anf_node);
  kernel_mod_ptr->SetNodeName(op_name);
  if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
    // Fixed log typo: "faild" -> "failed".
    MS_LOG(EXCEPTION) << "Create nodeDefBytes failed!";
  }
  if (!SetIOSize(anf_node, kernel_mod_ptr)) {
    MS_LOG(EXCEPTION) << "Set input output size list failed.";
  }
  return kernel_mod_ptr;
}
} // namespace kernel
} // namespace mindspore

+ 0
- 65
mindspore/ccsrc/kernel/aicpu/aicpu_util.h View File

@@ -1,65 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_
#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_

#include <cstdint>
#include <vector>
#include <map>
#include <string>
#include "kernel/kernel.h"

namespace mindspore {
namespace kernel {
// Op names that receive special handling when building aicpu kernels.
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
constexpr auto kPrint = "Print";
constexpr auto kPack = "Pack";

// Attribute names used when translating framework attrs to aicpu attrs.
constexpr auto kOutputTypes = "output_types";
constexpr auto kOutputShapes = "output_shapes";
constexpr auto kChannelName = "channel_name";
constexpr auto kSharedName = "shared_name";
constexpr auto kShapes = "shapes";
constexpr auto kTypes = "types";
constexpr auto kQueueName = "queue_name";
// Seed0/Seed1 (framework names) map to seed/seed2 (aicpu names).
constexpr auto kSeed = "seed";
constexpr auto kSeed0 = "Seed0";
constexpr auto kSeed1 = "Seed1";
constexpr auto kSeed2 = "seed2";
constexpr auto kTopK = "TopK";
constexpr auto kTopKV2 = "TopKV2";

// Header prepended to every aicpu kernel argument blob. Kept packed so the
// byte layout matches what the device-side aicpu runtime expects; do not
// reorder or pad these fields.
struct AicpuParamHead {
uint32_t length; // Total length: include custom message
uint32_t ioAddrNum; // Input and output address number
uint32_t extInfoLength; // extInfo struct Length
uint64_t extInfoAddr; // extInfo address
} __attribute__((packed));

// Static helpers shared by the aicpu kernel build/launch code.
class AicpuOpUtil {
public:
// Maps a MindSpore TypeId onto the aicpu proto DataType enum value.
static int MsTypeToProtoType(TypeId ms_type);

private:
// kernel id
static uint64_t KernelId_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_

+ 0
- 622
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc View File

@@ -1,622 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include <memory>
#include <map>
#include <set>
#include <utility>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "kernel/tbe/tbe_kernel_build.h"
#include "nlohmann/json.hpp"
#include "utils/context/ms_context.h"
#include "kernel/tbe/tbe_python_funcs.h"
#include "pre_activate/common/helper.h"
#include "kernel/tbe/tbe_convert_utils.h"
#include "parallel/ops_info/ops_utils.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"

namespace mindspore {
namespace kernel {
// Keys used when parsing op-select-format json responses.
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
// Parameter-type tags from the TBE op registry ("requre" spelling is the
// registry's, not a local typo).
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
tbe_selecter.TbeMetadataInfoEx();
}

// Binds the selecter to the node being queried and to the caller-owned list
// that will receive the candidate kernel build infos (pointer is not owned).
TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
: cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}

// Looks up the TBE op registry entry for this node, dispatches to the
// pattern-specific kernel-info generator, then filters out candidates whose
// shapes or checkSupported query fail.
void TbeKernelSelect::TbeMetadataInfoEx() {
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(kernel_info_list_);
  node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
  auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
  if (!op_info_ptr) {
    // Fixed log typo: "Cann't" -> "Can't".
    MS_LOG(INFO) << "Warning: Can't find tbe core opinfo, node type: " << node_name_;
    return;
  }
  MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_
               << ", node name: " << cnode_ptr_->fullname_with_scope();
  // Dispatch on the op's registered format pattern.
  OpPattern pattern = op_info_ptr->op_pattern();
  if (pattern == kCommonPattern) {
    GetCommonPatternKernelInfo(*op_info_ptr);
  } else if (pattern == kDynamicFormatPattern) {
    GetDynamicFormatPatternKernelInfo(*op_info_ptr);
  } else if (pattern == kFormatAgnosticPattern) {
    GetAgnosticPatternKernelInfo(*op_info_ptr);
  } else if (pattern == kBroadcastPattern) {
    GetBroadcastPatternKernelInfo(*op_info_ptr);
  } else if (pattern == kReducePattern) {
    GetReducePatternKernelInfo(*op_info_ptr);
  } else {
    // Fixed log typo: "invailed" -> "invalid".
    MS_LOG(INFO) << "Warning: op pattern is invalid.";
  }
  // check support
  FilterInVaildKernelInfo();
  MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
}

// Builds one KernelBuildInfo candidate per dtype "column" of the registered op
// info: column k combines the k-th dtype/format of every input and output
// descriptor. Candidates whose io lists cannot be generated are skipped (the
// loop breaks on the first failure).
void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
// get dynamic inputs
auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
MS_EXCEPTION_IF_NULL(primitive);
std::vector<int> dyn_input_sizes;
if (primitive->HasAttr(kAttrDynInputSizes)) {
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
// get real input/output num
size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
const auto inputs_info = op_info.inputs_ptr();
size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
const auto outputs_info = op_info.outputs_ptr();
if (inputs_info.empty() && outputs_info.empty()) {
MS_LOG(EXCEPTION) << "op info input & output is null, please check.";
}
// create kernel build info from opinfo
// The number of candidates equals the dtype-list length of the first io
// descriptor (all descriptors are expected to have equal-length dtype lists).
size_t kernel_build_info_num =
inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
SetTbeBuildCommonInfo(op_info, &builder);
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_device_type;
std::vector<std::vector<Axis>> inputs_reshape_type;
// input
if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
&inputs_format, &inputs_device_type, &inputs_reshape_type)) {
break;
}
builder.SetInputsDeviceType(inputs_device_type);
builder.SetInputsFormat(inputs_format);
builder.SetInputReshapeType(inputs_reshape_type);
// output
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_device_type;
std::vector<std::vector<Axis>> outputs_reshape_type;
if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
&outputs_format, &outputs_device_type, &outputs_reshape_type)) {
break;
}
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputReshapeType(outputs_reshape_type);
kernel_info_list_->emplace_back(builder.Build());
}
MS_LOG(INFO) << "end.";
}

void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
//
OpInfo op_info_new;
CreateNewOpInfo(op_info, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}

// Format-agnostic ops (single input): propagate the producer's output format
// to every input and output, falling back to the default format when the
// producer's format is unknown, then reuse the common-pattern path.
void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
if (op_info.inputs_ptr().size() != 1) {
MS_LOG(EXCEPTION) << "AgnosticPattern only support one input.";
}
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(INFO) << "Got the unknown format " << format;
format = kOpFormat_DEFAULT;
}
// Every io uses the single propagated format.
SupportFormat support_format;
SupportFormatItem input_item;
SupportFormatItem output_item;
input_item.assign(op_info.inputs_ptr().size(), format);
output_item.assign(op_info.outputs_ptr().size(), format);
support_format.input_format.emplace_back(input_item);
support_format.output_format.emplace_back(output_item);
PrintSupportedFormat(support_format);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}

void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_);
SupportFormat support_format;
broadcast_selecter.GetShapeInfo(&support_format);
if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD.";
}
if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ.";
}
if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0.";
}
if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
}
PrintSupportedFormat(support_format);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}

// Reduce ops: probe each special device format, logging the formats that are
// NOT supported, then reuse the common-pattern path on the accumulated
// support_format.
// FIX: the original checked the FracZ / C1HWNCoC0 / FracNZ probes WITHOUT
// negation, logging "reduce not support X" exactly when X *is* supported.
// This was inconsistent with both the 5HD check above and the broadcast
// sibling (GetBroadcastPatternKernelInfo); the missing '!' is restored.
void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
  MS_LOG(INFO) << "start.";
  auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
  SupportFormat support_format;
  reduce_selecter.GetShapeInfo(&support_format);
  if (!reduce_selecter.IsReduceSupport5HD(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support 5HD.";
  }
  if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracZ.";
  }
  if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support C1HWNCoC0.";
  }
  if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracNZ.";
  }
  PrintSupportedFormat(support_format);
  OpInfo op_info_new;
  CreateNewOpInfo(op_info, support_format, &op_info_new);
  GetCommonPatternKernelInfo(op_info_new);
  MS_LOG(INFO) << "end.";
}

// Removes candidates from kernel_info_list_ whose formats do not match the
// node's inferred shapes or whose TBE checkSupported query fails.
// Uses the erase-returns-next-iterator idiom, so the iterator is only
// advanced when nothing was erased.
void TbeKernelSelect::FilterInVaildKernelInfo() {
if (kernel_info_list_->empty()) {
MS_LOG(INFO) << "Warning: get kernel build info failed.";
return;
}
auto kernel_build_info_iter = kernel_info_list_->begin();
while (kernel_build_info_iter != kernel_info_list_->end()) {
if (!FilterInVaildShape(kernel_build_info_iter)) {
MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString();
kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
continue;
}
if (!TbeCheckSupported(kernel_build_info_iter)) {
MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString();
kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
continue;
}
kernel_build_info_iter++;
}
}

// Returns true when every input and output format of the candidate build info
// is compatible with the corresponding inferred shape on the node.
bool TbeKernelSelect::FilterInVaildShape(
  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
  // Inputs: compare against the producing node's inferred output shape.
  auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
  for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
    auto format = kernel_build_info_inputs_format.at(i);
    if (!IsShapeMatchFormat(shape, format)) {
      MS_LOG(INFO) << "The " << i << "th input check failed.";
      return false;
    }
  }
  // Outputs: compare against this node's own inferred output shapes.
  auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
  for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
    auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
    auto format = kernel_build_info_outputs_format.at(j);
    if (!IsShapeMatchFormat(shape, format)) {
      // Fixed copy-paste in the log: this loop checks outputs, not inputs.
      MS_LOG(INFO) << "The " << j << "th output check failed.";
      return false;
    }
  }
  return true;
}

bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
if (format == kOpFormat_DEFAULT) {
return true;
}
static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
// if format is default, it remarkes support all format
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(EXCEPTION) << "Got the unknown format " << format;
}
// server not support format with C04 suffix
if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
kServerNotSupportFormat.end()) {
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
return false;
}
// not support format:
// 1 NDHWC with shape size != 5
// 2 FRAC_NZ with shape size < 2
// 3 !NDHWC with shape size > 4
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
(format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
return false;
}
return true;
}

// Queries TBE's checkSupported interface for the candidate build info.
// Only a fixed set of op types needs the query; everything else is accepted.
bool TbeKernelSelect::TbeCheckSupported(
  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
  static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL,
                                                              parallel::BATCHMATMUL,
                                                              parallel::TOPK,
                                                              parallel::IN_TOPK,
                                                              parallel::PACK,
                                                              parallel::UNSORTEF_SEGMENT_MIND,
                                                              parallel::UNSORTEF_SEGMENT_PRODD,
                                                              parallel::CAST};
  // set::count is O(log n); the original linear std::find scan was unnecessary.
  if (kCheckSupportedOpType.count(node_name_) == 0) {
    return true;
  }
  MS_LOG(INFO) << "Check support start.";
  // Temporarily install this candidate on the node so the generated kernel
  // json reflects it, then restore the previous selection afterwards.
  auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
  AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
  nlohmann::json kernel_json;
  TbeKernelJsonCreator creator(CHECK_SUPPORTED);
  bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
  }
  ret = TbePythonFuncs::CheckSupported(kernel_json);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
  return ret;
}

// Applies the fields shared by every candidate built from |op_info|:
// processor, fusion type (when known), op pattern and kernel type.
void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
                                            mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetProcessor(AICORE);
  // Only record the fusion type when the registry string maps to a known value.
  const auto fusion_type = tbe::GetFusionType(op_info.fusion_type());
  if (fusion_type != UNKNOWN_FUSION_TYPE) {
    builder->SetFusionType(fusion_type);
  }
  builder->SetOpPattern(op_info.op_pattern());
  builder->SetKernelType(TBE_KERNEL);
}

// Builds per-tensor format / device-type / reshape-type lists for one kernel
// build info candidate (dtype/format "column" |kernel_build_info_index| of
// every io descriptor in |ios_info|). |is_input| selects input vs output
// semantics, |real_io_tensor_num| is the actual tensor count on the node, and
// |dyn_input_sizes| holds the per-dynamic-input tensor counts taken from the
// kAttrDynInputSizes attribute.
// Returns false when the registered io info cannot cover all real tensors.
bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats,
std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) {
MS_EXCEPTION_IF_NULL(formats);
MS_EXCEPTION_IF_NULL(device_types);
MS_EXCEPTION_IF_NULL(reshape_types);
size_t dynamic_input_index = 0;
size_t real_io_tensor_index = 0;
size_t io_info_index = 0;
size_t io_info_num = ios_info.size();
for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index);
std::string kernel_build_info_format;
if (!io_info_item->formats().empty()) {
kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index);
}
std::string io_param_type = io_info_item->param_type();
std::vector<Axis> reshape_type;
StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
if (io_param_type == kParamTypeDynamic) {
// dynamic io
if (is_input) {
// A dynamic input expands to dyn_input_sizes[dynamic_input_index] real
// tensors, all sharing this descriptor's dtype/format/reshape type.
if (dynamic_input_index >= dyn_input_sizes.size()) {
MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
<< ", dyn_input_sizes size: " << dyn_input_sizes.size();
}
int dynamic_input_size = dyn_input_sizes[dynamic_input_index];
for (int i = 0; i < dynamic_input_size; ++i) {
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
}
dynamic_input_index++;
real_io_tensor_index += dynamic_input_size;
} else {
// A dynamic output must be the only output descriptor; it then covers
// every real output tensor of the node.
if (ios_info.size() != 1) {
MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
}
for (size_t i = 0; i < real_io_tensor_num; ++i) {
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
}
real_io_tensor_index += real_io_tensor_num;
}
} else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
// requre or optional io: exactly one real tensor per descriptor.
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
real_io_tensor_index++;
} else {
MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
}
}

// Leftover descriptors are tolerated (optional ios may be absent)...
if (io_info_index != io_info_num) {
MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num
<< "), this node may has optional input/output.";
}
// ...but leftover real tensors mean the op info does not describe this node.
if (real_io_tensor_index != real_io_tensor_num) {
std::string io_type = is_input ? "inputs " : "outputs";
MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num
<< ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index
<< ") != real_io_tensor_num(" << real_io_tensor_num << ")";
return false;
}
return true;
}

// Converts a reshape-type string such as "NCH" into the corresponding Axis
// enum sequence appended to |reshape_type_vec|.
// Throws on any character outside {N, C, H, W}.
void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
  MS_EXCEPTION_IF_NULL(reshape_type_vec);
  for (const auto &c : reshape_type_str) {
    switch (c) {
      case 'N':
        reshape_type_vec->push_back(kernel::N);
        break;
      case 'C':
        reshape_type_vec->push_back(kernel::C);
        break;
      case 'H':
        reshape_type_vec->push_back(kernel::H);
        break;
      case 'W':
        reshape_type_vec->push_back(kernel::W);
        break;
      default:
        // Fixed message: original printed "Unknown axis Xin reshape type."
        // (missing space before "in").
        MS_LOG(EXCEPTION) << "Unknown axis " << c << " in reshape type.";
    }
  }
}

// Expands one io descriptor across the supported-format groups: the dtype list
// is replicated once per format group, and for each group the format at
// position |index| (this io's slot) is repeated once per dtype. The resulting
// dtype and format lists therefore stay the same length (columns align).
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
const std::vector<std::vector<std::string>> &support_format_item, size_t index,
mindspore::kernel::OpIOInfo *op_io_info_new) {
MS_EXCEPTION_IF_NULL(op_io_info_new);
// Scalar fields carry over unchanged.
op_io_info_new->set_index(op_io_info.index());
op_io_info_new->set_name(op_io_info.name());
op_io_info_new->set_param_type(op_io_info.param_type());
op_io_info_new->set_need_compile(op_io_info.need_compile());
op_io_info_new->set_reshape_type(op_io_info.reshape_type());
op_io_info_new->set_shape(op_io_info.shape());
// dtype
std::vector<std::string> dtype_new;
auto dtype = op_io_info.dtypes();
for (size_t i = 0; i < support_format_item.size(); ++i) {
dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end());
}
op_io_info_new->set_dtypes(dtype_new);
// format
std::vector<std::string> format_new;
for (const auto &formats : support_format_item) {
auto format = formats.at(index);
for (size_t j = 0; j < dtype.size(); ++j) {
format_new.emplace_back(format);
}
}
op_io_info_new->set_formats(format_new);
}

// Splits a comma-separated op-select json item into tokens. Leading spaces of
// each token are stripped (trailing spaces are kept, matching the original
// behavior), and a few legacy format names are normalized via kDynamicFormatMap.
// Throws on an empty item or when the item starts with a separator.
std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) {
  const std::map<std::string, std::string> kDynamicFormatMap = {
    {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}};
  if (op_select_json_item.empty()) {
    MS_LOG(EXCEPTION) << "Op select ret item is null.";
  }
  const char kSpace = ' ';
  const char kSep = ',';
  // Append a trailing separator so every token is terminated uniformly.
  const std::string padded = op_select_json_item + ",";
  std::vector<std::string> tokens;
  std::string::size_type token_begin = padded.find_first_not_of(kSpace, 0);
  std::string::size_type comma_pos = padded.find(kSep);
  if (token_begin >= comma_pos) {
    MS_LOG(EXCEPTION) << "Select ret json is error.";
  }
  while (comma_pos != std::string::npos) {
    std::string token = padded.substr(token_begin, comma_pos - token_begin);
    auto mapped = kDynamicFormatMap.find(token);
    if (mapped != kDynamicFormatMap.end()) {
      token = mapped->second;
    }
    tokens.emplace_back(token);
    token_begin = padded.find_first_not_of(kSpace, comma_pos + 1);
    comma_pos = padded.find(kSep, token_begin);
  }
  return tokens;
}

// Query TBE's python-side op_select_format for the current cnode and return
// the json response string.
// Throws if the kernel json cannot be generated or the query returns an
// empty result, so callers always receive a non-empty string.
std::string TbeKernelSelect::OpSelectFormat() {
  nlohmann::json kernel_json;
  std::string res_json_str;
  TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
  bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
  if (!ret) {
    MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
  }
  res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
  if (res_json_str.empty()) {
    MS_LOG(EXCEPTION) << "op select format error.";
  }
  // Fixed typo in the log message ("foramt" -> "format").
  MS_LOG(INFO) << "Dynamic select format response result:" << res_json_str;
  return res_json_str;
}

// Rebuild `op_info_new` from `op_info`, replacing each input/output's
// dtype/format lists with the combinations expanded from `support_format`.
// The input/output counts of the op info must match the width of the first
// support-format row; otherwise an exception is raised.
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format,
                                      mindspore::kernel::OpInfo *op_info_new) {
  MS_EXCEPTION_IF_NULL(op_info_new);
  const auto input_num = op_info.inputs_ptr().size();
  const auto output_num = op_info.outputs_ptr().size();
  if (input_num != support_format.input_format[0].size() || output_num != support_format.output_format[0].size()) {
    MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << input_num
                      << ", input support size: " << support_format.input_format[0].size()
                      << ", op info output size: " << output_num
                      << ", output support size: " << support_format.output_format[0].size();
  }
  // Start from a full copy, then regenerate the io lists.
  *op_info_new = op_info;
  op_info_new->ClearInputs();
  op_info_new->ClearOutputs();
  for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
    auto new_input = std::make_shared<OpIOInfo>();
    CreateNewOpIOInfo(*op_info.inputs_ptr().at(input_idx), support_format.input_format, input_idx, new_input.get());
    op_info_new->add_inputs_ptr(new_input);
  }
  for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
    auto new_output = std::make_shared<OpIOInfo>();
    CreateNewOpIOInfo(*op_info.outputs_ptr().at(output_idx), support_format.output_format, output_idx,
                      new_output.get());
    op_info_new->add_outputs_ptr(new_output);
  }
}

// One input/output entry parsed from the op-select-format json response:
// the io name plus its supported dtype and format lists (position-aligned,
// produced by SplitStrToVec).
struct SelectOpIOInfo {
  std::string name;
  std::vector<std::string> dtypes;
  std::vector<std::string> formats;
};

// Rebuild `op_info_new` from `op_info` using the dynamically selected
// dtype/format lists returned by TBE's op_select_format query.  The json
// response is expected to be an object whose keys contain "input"/"output"
// prefixes; input/output counts must match the registered op info.
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
                                      mindspore::kernel::OpInfo *op_info_new) {
  MS_EXCEPTION_IF_NULL(op_info_new);
  auto op_select_json = OpSelectFormat();
  if (!op_select_json.empty()) {
    nlohmann::json json_obj = nlohmann::json::parse(op_select_json);
    if (!json_obj.is_object()) {
      MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_select_json;
    }
    // Parse one json entry into its name / dtype-list / format-list triple.
    auto parse_io_entry = [this](const nlohmann::json &value) {
      SelectOpIOInfo io_entry;
      io_entry.name = value.at(kName);
      std::string dtype_item = value.at(kDtype);
      io_entry.dtypes = SplitStrToVec(dtype_item);
      std::string format_item = value.at(kFormat);
      io_entry.formats = SplitStrToVec(format_item);
      return io_entry;
    };
    std::vector<SelectOpIOInfo> inputs;
    std::vector<SelectOpIOInfo> outputs;
    for (const auto &item : json_obj.items()) {
      const std::string &item_name = item.key();
      if (item_name.find(kPrefixInput) != std::string::npos) {
        inputs.emplace_back(parse_io_entry(item.value()));
      } else if (item_name.find(kPrefixOutput) != std::string::npos) {
        outputs.emplace_back(parse_io_entry(item.value()));
      } else {
        MS_LOG(EXCEPTION) << "op select ret json is error.";
      }
    }

    if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
      MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register.";
    }

    // Start from a full copy, then regenerate the io lists with the selected
    // dtype/format combinations.
    *op_info_new = op_info;
    op_info_new->ClearInputs();
    op_info_new->ClearOutputs();
    for (size_t idx = 0; idx < op_info.inputs_ptr().size(); ++idx) {
      auto new_input = std::make_shared<OpIOInfo>();
      CreateNewOpIOInfo(*op_info.inputs_ptr().at(idx), inputs.at(idx).dtypes, inputs.at(idx).formats, new_input.get());
      op_info_new->add_inputs_ptr(new_input);
    }
    for (size_t idx = 0; idx < op_info.outputs_ptr().size(); ++idx) {
      auto new_output = std::make_shared<OpIOInfo>();
      CreateNewOpIOInfo(*op_info.outputs_ptr().at(idx), outputs.at(idx).dtypes, outputs.at(idx).formats,
                        new_output.get());
      op_info_new->add_outputs_ptr(new_output);
    }
  }
}

// Copy the static properties of `op_io_info` into `op_io_info_new`, replacing
// only its supported dtype and format lists with the dynamically selected
// ones (the two lists are expected to be position-aligned).
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
                                        const std::vector<std::string> &support_dtype,
                                        const std::vector<std::string> &support_format,
                                        mindspore::kernel::OpIOInfo *op_io_info_new) {
  MS_EXCEPTION_IF_NULL(op_io_info_new);
  // Properties copied through unchanged.
  op_io_info_new->set_name(op_io_info.name());
  op_io_info_new->set_index(op_io_info.index());
  op_io_info_new->set_param_type(op_io_info.param_type());
  op_io_info_new->set_shape(op_io_info.shape());
  op_io_info_new->set_reshape_type(op_io_info.reshape_type());
  op_io_info_new->set_need_compile(op_io_info.need_compile());
  // Dynamically selected dtype/format pairs.
  op_io_info_new->set_dtypes(support_dtype);
  op_io_info_new->set_formats(support_format);
}

// Log every supported input->output format combination at INFO level, one
// line per combination.  Throws if the input and output tables differ in
// row count.
void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
  if (support_format.input_format.size() != support_format.output_format.size()) {
    MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
                      << support_format.output_format.size() << ") size not match.";
  }
  // Append every item of one format list, each followed by ", ".
  auto append_items = [](const std::vector<std::string> &items, std::string *dst) {
    for (const auto &item : items) {
      dst->append(item);
      dst->append(", ");
    }
  };
  for (size_t row = 0; row < support_format.input_format.size(); ++row) {
    std::string line = "[";
    append_items(support_format.input_format.at(row), &line);
    line.append("] -->");
    append_items(support_format.output_format.at(row), &line);
    MS_LOG(INFO) << "Support format: " << line;
  }
}
} // namespace kernel
} // namespace mindspore

+ 0
- 492
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc View File

@@ -1,492 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ascend_backend_optimization.h"
#include <memory>
#include <string>
#include <set>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ir_fission/bn_split.h"
#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h"
#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h"
#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h"
#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
#include "pre_activate/pass/communication_op_fusion.h"
#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h"
#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h"
#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h"
#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h"
#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h"
#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h"
#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h"
#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h"
#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h"
#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
#include "pre_activate/ascend/ir_fission/transdata_split.h"
#include "pre_activate/ascend/ir_fission/topk_split.h"
#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h"
#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h"
#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h"
#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h"
#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h"
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/pass/optimize_dependence.h"
#include "pre_activate/pass/erase_visit_attr.h"
#include "pre_activate/ascend/format_type/insert_cast.h"
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include "pre_activate/pass/eliminate_redundant_op.h"
#include "pre_activate/pass/common_subexpression_elimination.h"
#include "pre_activate/pass/fuse_graph_kernel.h"
#include "pre_activate/pass/fuse_basic.h"
#include "pre_activate/pass/add_atomic_clean.h"
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
#include "pre_activate/ascend/format_type/check_consistency.h"
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
#include "pre_activate/ascend/ir_fission/addn_fission.h"
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "pre_activate/ascend/ir_fission/split_fission.h"
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_utils.h"

namespace mindspore {
namespace opt {
namespace {
// Register the optional IR fusion passes on the given pass manager.  Passes
// execute in registration order, so the order below is significant.  Only
// invoked when the context's ir_fusion_flag is enabled (see
// AscendBackendIRFusionOptimization).
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
  MS_EXCEPTION_IF_NULL(ir_fusion_pm);
  // BatchNorm fission and elementwise fusion rules.
  ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
  ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
  ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
  ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>());
  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV2>());
  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV3>());
  ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
  // Lamb optimizer fusion rules (each Cond* variant matches one pattern).
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond1>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond2>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond3>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
  ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
  ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
  ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
  ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
  ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
  // Adam optimizer fusion rules.
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond4Fusion>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond1>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond2>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond3>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond4>());
  ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond5>());
  ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
  ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
  ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
  ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
  ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
  ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
  // GetitemTuple runs a second time to clean up after the fission passes.
  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
}
} // namespace

// Data-layout pipeline for PyNative (run-op) mode: inserts trans-data ops,
// eliminates redundant/common subexpressions, then rebuilds the kernel
// graph's execution order.
void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
  // Pass order is significant: trans-data insertion first, cleanup after.
  data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
  data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
  data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
  data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
  data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
  data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
  data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
  optimizer->AddPassManager(data_layout_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
}

// Graph-kernel common preprocessing: runs ModifyOpAttrs followed by
// RemoveNoUseReshapeOp on the kernel graph, then rebuilds its execution
// order.
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto graph_optimizer = std::make_shared<GraphOptimizer>();
  MS_EXCEPTION_IF_NULL(graph_optimizer);
  auto pass_manager = std::make_shared<PassManager>("graph_kernel_common_process");
  MS_EXCEPTION_IF_NULL(pass_manager);
  // Pass order is significant and preserved.
  pass_manager->AddPass(std::make_shared<ModifyOpAttrs>());
  pass_manager->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
  graph_optimizer->AddPassManager(pass_manager);
  (void)graph_optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
}

// Data-layout pipeline for graph mode: inserts trans ops, removes redundant
// reshape pairs and ops, then rebuilds the kernel graph's execution order.
// Differs from RunOpAscendDataLayout by using InsertTransOp and adding
// RemoveReshapePair.
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
  // Pass order is significant: insertion first, cleanup after.
  data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
  data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
  data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
  data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>());
  data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
  data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
  data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
  data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
  optimizer->AddPassManager(data_layout_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
}

// Mixed-precision pipeline: inserts Cast ops, handles ref/trans/cast
// interactions (DealRefTransAndCast), merges casts into ops where possible,
// and converts unsupported trans nodes to AICPU.  Rebuilds the execution
// order afterwards.
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
  // Pass order is significant: casts are inserted first, then cleaned up and
  // merged.
  mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
  mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
  mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
  mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
  mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
  mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
  mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
  mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
  mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
  mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
  mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
  optimizer->AddPassManager(mixed_precision_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
}

// IR fusion pipeline for graph mode.  Always runs the BatchNorm/LayerNorm
// split-and-fuse passes; the larger optional set is added only when the
// context's ir_fusion_flag is enabled.  Dumps the graph before/after when
// save_graphs is set.
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  // kernel_graph is dereferenced below (graph_id(), Optimize); validate it up
  // front, consistent with AscendDataLayout/AscendMixPrecision in this file.
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" +
                            std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph);
    DumpIRProto(kernel_graph, "before_hwopt_" + std::to_string(kernel_graph->graph_id()));
  }
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
  // Mandatory passes, run regardless of ir_fusion_flag.
  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  if (context_ptr->ir_fusion_flag()) {
    AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
  }

  // GetNext memcpy handling only applies under task sink + loop sink with
  // more than one iteration.
  if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
    ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
    ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
    ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
  }
  ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
  optimizer->AddPassManager(ir_fusion_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph);
  }
}

// IR fusion pipeline for PyNative (run-op) mode.  Returns immediately when
// ir_fusion_flag is disabled.  Dumps the graph before/after when save_graphs
// is set.
// NOTE(review): kernel_graph is dereferenced (DumpIR / Optimize /
// SetExecOrderByDefault) without an MS_EXCEPTION_IF_NULL check, unlike
// AscendDataLayout above — confirm callers guarantee non-null.
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->ir_fusion_flag()) {
    MS_LOG(INFO) << "IRFusion is not enable, skip";
    return;
  }
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir";
    DumpIR(file_path, kernel_graph);
  }
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
  // A reduced pass set compared with the graph-mode pipeline.
  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
  ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());

  optimizer->AddPassManager(ir_fusion_pm)\u003b
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir";
    DumpIR(file_path, kernel_graph);
  }
}

// Top-level backend optimization pipeline: data layout, mixed precision,
// communication-op fusion, UB buffer fusion, then final cleanup.  Dumps the
// graph before/after when save_graphs is set.
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  // kernel_graph is dereferenced throughout (graph_id(), SetExecOrderByDefault);
  // validate it up front, consistent with the other pipeline entry points.
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph);
  }
  // data layout optimization
  AscendDataLayout(kernel_graph);
  // mixed precision optimization
  AscendMixPrecision(kernel_graph);
  // other optimization: communication-op fusion and parameter trans-op fusion
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto other_pm = std::make_shared<PassManager>("other_pm");
  other_pm->AddPass(std::make_shared<AllReduceFusion>());
  other_pm->AddPass(std::make_shared<AllGatherFusion>());
  other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
  other_pm->AddPass(std::make_shared<BroadcastFusion>());
  other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
  other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
  optimizer->AddPassManager(other_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  // buffer fusion
  AscendBackendUBFusionOptimization(kernel_graph);

  // other2 optimization: final cleanup and consistency check
  auto optimizer2 = std::make_shared<GraphOptimizer>();
  auto other2_pm = std::make_shared<PassManager>("other2_pm");
  other2_pm->AddPass(std::make_shared<GetitemTuple>());
  other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  // GetNext memcpy elimination only applies under task sink + loop sink with
  // more than one iteration.
  if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
    other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
  }
  other2_pm->AddPass(std::make_shared<CheckConsistency>());
  optimizer2->AddPassManager(other2_pm);
  (void)optimizer2->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();

  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph, true);
    DumpIRProto(kernel_graph, "after_hwopt");
    kernel_graph->DumpFuncGraph("hwopt_d_end");
  }
}

// Fuse graph kernels with basic ops.  No-op unless enable_graph_kernel is set
// in the context.  `is_before_kernel_select` is forwarded to FuseGraphKernel
// and (negated) embedded in the dump file names to distinguish the two call
// sites.
// NOTE(review): kernel_graph is dereferenced (graph_id()) without an
// MS_EXCEPTION_IF_NULL check — confirm callers guarantee non-null.
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                 bool is_before_kernel_select) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!(context_ptr->enable_graph_kernel())) {
    return;
  }
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" +
                            std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
                            ".ir";
    DumpIR(file_path, kernel_graph);
  }

  // Fuse graph kernels with basic ops
  FuseGraphKernel(kernel_graph, is_before_kernel_select);

  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" +
                            std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
                            ".ir";
    DumpIR(file_path, kernel_graph, true);
  }
}

// Fuse basic ops with basic ops (graph-kernel mode).  No-op unless
// enable_graph_kernel is set in the context.  `is_before_kernel_select` is
// forwarded to FuseBasic and (negated) embedded in the dump file names.
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                               bool is_before_kernel_select) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!(context_ptr->enable_graph_kernel())) {
    return;
  }
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" +
                            std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
                            ".ir";
    DumpIR(file_path, kernel_graph, true);
  }

  // Fuse basic ops with basic ops
  FuseBasic(kernel_graph, is_before_kernel_select);

  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" +
                            std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
                            ".ir";
    DumpIR(file_path, kernel_graph, true);
  }
}

// Insert atomic-clean ops for graph kernels.  No-op unless
// enable_graph_kernel is set in the context.
// NOTE(review): the "after" dump uses the name "hwopt_d_end", which collides
// with AscendBackendOptimization's final dump instead of a matching
// "hwopt_d_add_atomic_clean_after" name — possibly a copy/paste; confirm
// whether this is intended.
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!(context_ptr->enable_graph_kernel())) {
    return;
  }
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" +
                            std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph);
  }

  AddAtomicClean(kernel_graph);

  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph, true);
  }
}

// UB (unified buffer) fusion pipeline: assigns fusion ids to candidate op
// groups via the pattern passes, then UbPatternFusion performs the actual
// buffer fusion.  Skipped entirely when ir_fusion_flag is disabled.
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->ir_fusion_flag()) {
    MS_LOG(INFO) << "UBFusion is not enable, skip";
    return;
  }
  bool save_graphs = context_ptr->save_graphs_flag();
  auto save_graphs_path = context_ptr->save_graphs_path();
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph);
  }
  // All pattern passes share one id allocator so fusion ids stay unique.
  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  fusion_id_allocator->Init();
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm");
  // Pattern passes run from most specific to most generic; order matters.
  ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<MultiOutputFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
  ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
  // Performs the actual fusion using the ids assigned above.
  ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());
  optimizer->AddPassManager(ub_fusion_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph);
  }
}
} // namespace opt
} // namespace mindspore

+ 0
- 226
mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc View File

@@ -1,226 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
#include <utility>
#include <vector>
#include <memory>
#include <string>
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
// Recursively trace `node` back through ref ops and layout/type ops
// (Cast/Transpose/Reshape/TransData) to the (node, output-index) pair that
// originally produced the referenced value.
session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
  session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
  AnfNodePtr cur_node = kernel_with_index.first;
  size_t cur_out_index = kernel_with_index.second;
  MS_EXCEPTION_IF_NULL(cur_node);
  if (cur_node->isa<CNode>()) {
    auto cnode = cur_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    std::string op_name = AnfAlgo::GetCNodeName(cnode);
    auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
    // deal ref op: follow the input that this output is a ref (alias) of.
    if (op_info != nullptr && op_info->is_ref()) {
      auto ref_infos = op_info->ref_infos();
      if (ref_infos.count(cur_out_index) != 0) {
        auto in_index = ref_infos.at(cur_out_index);
        // NOTE(review): input(in_index + 1) is accessed below (index 0 is the
        // primitive), but this guard only rejects in_index > inputs().size();
        // looks off-by-one (in_index + 1 >= inputs().size() would still pass)
        // — confirm against CNode::input's precondition.
        if (in_index > cnode->inputs().size()) {
          MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size()
                            << ", ref info is " << cur_out_index;
        }
        AnfNodePtr next_node = cnode->input(in_index + 1);
        return FindRefOriginNode(next_node);
      }
    }

    // deal special (trans,cast,reshape) op: these forward their first input,
    // so keep tracing through it.
    if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
        op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
      AnfNodePtr next_node = cnode->input(1);
      return FindRefOriginNode(next_node);
    }
  }

  return kernel_with_index;
}

// Register a (final output, origin) ref pair on the kernel graph so the
// runtime can alias their device memory.
void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item,
                             const AnfNodePtr &final_node, size_t final_index,
                             const session::KernelWithIndex &origin_pair) {
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // When no trans/cast was inserted, final_node is the tuple-get-item node,
  // which is removed later; record the pair against the cnode itself instead.
  const bool no_additional_node = (final_node == get_item);
  AnfNodePtr final_ref = no_additional_node ? cnode : final_node;
  session::AnfWithOutIndex final_pair{final_ref, final_index};
  if (kernel_graph->IsInRefOutputMap(final_pair)) {
    MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is "
                      << final_index;
  }
  MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is "
                << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr "
                << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index "
                << origin_pair.second << "}";
  kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair);
}

// Insert the trans/cast/depend nodes needed to make a ref output usable and
// record the ref pair. The insertion order (trans -> cast -> ref pair ->
// depend) is deliberate: the ref pair must point at the last inserted
// data-producing node, and the depend is added afterwards so it is not part
// of the recorded pair.
// if get_item is nullptr, the additional node will link to the cnode
// else the additional node will link to the get_item node (the get_item node link to cnode)
AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index,
                                    size_t input_index, const AnfNodePtr &get_item) {
  AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item);
  size_t final_index = output_index;
  AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
  session::KernelWithIndex origin_pair;
  // Trace the ref'd input back to the node that originally owns the buffer.
  origin_pair = FindRefOriginNode(input_node);
  MS_EXCEPTION_IF_NULL(origin_pair.first);
  if (!origin_pair.first->isa<Parameter>()) {
    MS_LOG(WARNING) << "ref op origin node is not parameter";
  }
  MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
                << origin_pair.first->DebugString() << ", index is " << origin_pair.second;
  auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second);
  auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second);
  auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
  auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
  auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
  // insert trans op when the device formats differ (skipped for rank<=1
  // shapes, where format conversion is a no-op)
  if (origin_format != cur_format && cur_shape.size() > 1) {
    auto kernel_select = std::make_shared<KernelSelect>();
    final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
    RefreshKernelBuildInfo(cur_format, origin_format, final_node);
    // the inserted node has a single output, so the ref pair index resets to 0
    final_index = 0;
    MS_EXCEPTION_IF_NULL(final_node);
    MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
  }
  // insert cast op when the device data types differ
  if (origin_type != cur_type) {
    final_node =
      AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, cur_shape, cur_type);
    MS_EXCEPTION_IF_NULL(final_node);
    final_node->set_scope(cnode->scope());
    final_index = 0;
    MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString();
  }
  // add ref pair pointing at the last data-producing node inserted above
  AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair);
  // insert depend so the inserted trans/cast runs after the ref op itself
  if (origin_format != cur_format || origin_type != cur_type) {
    std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node};
    final_node = func_graph->NewCNode(depend_nodes);
    MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString();
  }

  return final_node;
}
// For a multi-output ref op: wrap every output in a tuple-get-item, attach
// the additional trans/cast nodes to the ref outputs, and bundle all of them
// back into a MakeTuple that replaces the original node.
AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                    const std::shared_ptr<kernel::OpInfo> &op_info) {
  MS_EXCEPTION_IF_NULL(op_info);
  auto ref_infos = op_info->ref_infos();
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList abstract_list;
  const size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
    // only ref outputs need the additional trans/cast handling
    auto ref_iter = ref_infos.find(output_index);
    if (ref_iter != ref_infos.end()) {
      final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, ref_iter->second, final_node);
    }
    MS_EXCEPTION_IF_NULL(final_node);
    abstract_list.push_back(final_node->abstract());
    make_tuple_inputs.push_back(final_node);
  }
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  return make_tuple;
}

// For a single-output ref op: attach the additional trans/cast handling for
// its (single) ref mapping, or return nullptr when the op has none.
AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                              const std::shared_ptr<kernel::OpInfo> &op_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(op_info);
  auto ref_infos = op_info->ref_infos();
  if (ref_infos.empty()) {
    return nullptr;
  }
  // a single-output op has at most one relevant mapping: the first entry
  const auto &ref_info = *ref_infos.begin();
  if (ref_info.second > cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is "
                      << ref_info.second;
  }
  return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr);
}
} // namespace

// Pattern: any not-yet-visited node with an arbitrary argument list.
const BaseRef DealRefTransAndCast::DefinePattern() const {
  VarPtr node_var = std::make_shared<CondVar>(UnVisited);
  VarPtr inputs_var = std::make_shared<SeqVar>();
  return VectorRef({node_var, inputs_var});
}

void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; ++i) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
auto input_node = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(input_node);
MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
}
}
}

// Entry point of the pass: for every real ref kernel, insert the trans/cast
// nodes its ref outputs need and record the ref pairs on the kernel graph.
// Returns the replacement node, or nullptr when the node is left unchanged.
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  // mark visited so the UnVisited pattern does not match this node again
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
    return nullptr;
  }

  DealBroadCastAsRef(graph, cnode);

  auto op_name = AnfAlgo::GetCNodeName(cnode);
  auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
  if (op_info == nullptr || !op_info->is_ref()) {
    return nullptr;
  }
  // fix: op_info->is_ref() is guaranteed true after the guard above, so the
  // former redundant `if (op_info->is_ref())` wrapper (with its dead
  // `return nullptr` fallthrough) is removed; dispatch on output arity only.
  auto type = cnode->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (!type->isa<Tuple>()) {
    return DealRefSigleOutput(graph, cnode, op_info);
  }
  return DealRefForMultipleOutput(graph, cnode, op_info);
}
} // namespace opt
} // namespace mindspore

+ 0
- 64
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc View File

@@ -1,64 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/pass/convert_const_input_to_attr.h"

#include <vector>
#include <string>
#include <unordered_map>
#include <memory>

#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/common_utils.h"

namespace mindspore {
namespace opt {
// Entry point of the pass: for every real kernel (or every kernel inside a
// graph-kernel sub graph), convert its registered constant inputs into node
// attributes. Returns the (possibly modified) node itself.
const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
    return nullptr;
  }
  std::vector<AnfNodePtr> todos;
  if (AnfAlgo::IsGraphKernel(node)) {
    // graph kernel: process every valid kernel node of the sub graph
    auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
    MS_EXCEPTION_IF_NULL(sub_graph);
    kernel::GetValidKernelNodes(sub_graph, &todos);
  } else {
    todos.push_back(node);
  }

  for (auto &t : todos) {
    CNodePtr cnode = t->cast<CNodePtr>();
    // fix: guard the cast result before dereferencing it below (matches the
    // null-check convention used by the other passes in this file)
    MS_EXCEPTION_IF_NULL(cnode);
    // hoist the op name: it was previously recomputed up to three times
    const auto op_name = AnfAlgo::GetCNodeName(cnode);
    ConstInputToAttrInfoRegister reg;
    if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_name, &reg)) {
      continue;
    }
    if (op_name == prim::kPrimEmbeddingLookup->name() || op_name == prim::kPrimEmbeddingLookupCommGrad->name()) {
      // these ops only take the const-to-attr path when a primitive target
      // attribute has been set on the node
      if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
        continue;
      }
    }
    ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
  }
  return node;
}
} // namespace opt
} // namespace mindspore

Loading…
Cancel
Save