|
|
|
@@ -31,54 +31,6 @@ namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; |
|
|
|
namespace { |
|
|
|
// Build a fresh KernelBuildInfo for a trans op: a single input/output pair with
// the given formats and device type, inheriting the kernel type, fusion type
// and processor from the original build info.
// NOTE(review): the `node` parameter is unused here; presumably kept for
// interface stability — confirm before removing.
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
                                                  const AnfNodePtr &node, const TypeId device_type,
                                                  const kernel::KernelBuildInfo &ori_build_info) {
  KernelBuildInfoBuilder builder;
  // Exactly one input and one output; both sides share the same device type.
  builder.SetInputsFormat({input_format});
  builder.SetInputsDeviceType({device_type});
  builder.SetOutputsFormat({output_format});
  builder.SetOutputsDeviceType({device_type});
  // Carry over the selection attributes of the original kernel build info.
  builder.SetKernelType(ori_build_info.kernel_type());
  builder.SetFusionType(ori_build_info.fusion_type());
  builder.SetProcessor(ori_build_info.processor());
  return builder.Build();
}
|
|
|
|
|
|
|
// Create a new trans op CNode (named `op_name`, e.g. TransData) with `input`
// as its single data input, run kernel selection on it, mark it visited and
// give it the input's scope.
// If `need_padding` is true the new node's inferred output shape is the
// input's shape padded to 4D; otherwise the input's shape is used unchanged.
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
                        const bool need_padding, const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input);
  // Validate kernel_select up front (originally checked only just before use,
  // after the node had already been built).
  MS_EXCEPTION_IF_NULL(kernel_select);
  std::vector<AnfNodePtr> trans_inputs;
  auto prim = std::make_shared<Primitive>(op_name);
  trans_inputs.push_back(NewValueNode(prim));
  trans_inputs.push_back(input);
  CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
  MS_EXCEPTION_IF_NULL(trans_node);
  if (need_padding) {
    // If padding is needed, set the transdata node's shape to the padded 4D shape.
    auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                        {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
                                        trans_node.get());
  } else {
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                        {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
  }
  // special handle for ut: make sure kernel_info exists before selection
  if (trans_node->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    trans_node->set_kernel_info(kernel_info);
  }
  kernel_select->SelectKernel(trans_node);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
  trans_node->set_scope(input->scope());
  return trans_node;
}
|
|
|
|
|
|
|
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, |
|
|
|
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { |
|
|
|
std::vector<AnfNodePtr> trans_inputs; |
|
|
|
@@ -94,6 +46,58 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i |
|
|
|
return reshape; |
|
|
|
} |
|
|
|
|
|
|
|
// Insert a TransData node (plus a Reshape when padding is required) around `node`.
// is_insert_input == true : act on node's input `insert_index`
//                           (reshape[padded] -> transdata[padded] -> node).
// is_insert_input == false: act on node's output 0
//                           (node -> transdata[padded] -> reshape[original shape]).
// Returns the node callers should wire in (the transdata, or the trailing reshape).
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  // Check node before it is first dereferenced (originally this check came
  // after several AnfAlgo calls that already used it).
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr trans_node = nullptr;
  AnfNodePtr input_node = node;
  CNodePtr trans_data = nullptr;
  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
  std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
  // if insert transdata for input we need to change the input
  if (is_insert_input) {
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
    }
    auto cnode = node->cast<CNodePtr>();
    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
    // NOTE(review): reshape type is queried at index 0 while the other input
    // attributes above use insert_index — confirm this asymmetry is intended.
    padding_axis = AnfAlgo::GetInputReshapeType(node, 0);
  }
  // Padding is decided by dst_format on the input side and by input_format on
  // the output side (collapsed from a duplicated if/else on is_insert_input).
  const std::string &format_to_check = is_insert_input ? dst_format : input_format;
  bool need_padding = trans::IsNeedPadding(format_to_check, AnfAlgo::GetOutputInferShape(input_node, 0).size());
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    auto reshape_node =
      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis);
  return trans_node;
}
|
|
|
|
|
|
|
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, |
|
|
|
const KernelSelectPtr &kernel_select) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -111,13 +115,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & |
|
|
|
<< "when inserting the transdata node " << node->DebugString(); |
|
|
|
} |
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); |
|
|
|
std::string origin_format = kOpFormat_DEFAULT; |
|
|
|
std::string dest_format = AnfAlgo::GetInputFormat(node, index); |
|
|
|
if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { |
|
|
|
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) |
|
|
|
<< " To DefaultFormat , index: " << index; |
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName, |
|
|
|
true); |
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); |
|
|
|
} |
|
|
|
return input_node; |
|
|
|
} |
|
|
|
@@ -131,12 +133,9 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An |
|
|
|
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " |
|
|
|
<< node->DebugString(); |
|
|
|
} |
|
|
|
std::string origin_format = output_format; |
|
|
|
std::string dest_format = kOpFormat_DEFAULT; |
|
|
|
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { |
|
|
|
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; |
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName, |
|
|
|
false); |
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); |
|
|
|
} |
|
|
|
return node; |
|
|
|
} |
|
|
|
@@ -155,10 +154,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const |
|
|
|
} |
|
|
|
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); |
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); |
|
|
|
std::string dest_format = kOpFormat_DEFAULT; |
|
|
|
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { |
|
|
|
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format, |
|
|
|
dest_format, kTransDataOpName, false)); |
|
|
|
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); |
|
|
|
} else { |
|
|
|
// No need insert trans op. |
|
|
|
make_tuple_inputs.push_back(tuple_getitem); |
|
|
|
@@ -168,62 +165,54 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, |
|
|
|
const std::string &origin_format, const std::string &dest_format, |
|
|
|
const std::string &op_name, bool is_insert_input) { |
|
|
|
AnfNodePtr trans_node = nullptr; |
|
|
|
AnfNodePtr input_node = node; |
|
|
|
AnfNodePtr trans_data = nullptr; |
|
|
|
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (origin_format.empty() || dest_format.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format; |
|
|
|
} |
|
|
|
// if insert transdata for input we need to change the input |
|
|
|
if (is_insert_input) { |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
input_node = AnfAlgo::GetInputNode(cnode, insert_index); |
|
|
|
} |
|
|
|
bool need_padding = false; |
|
|
|
if (is_insert_input) { |
|
|
|
need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) && |
|
|
|
op_name == kTransDataOpName); |
|
|
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, |
|
|
|
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { |
|
|
|
MS_EXCEPTION_IF_NULL(trans_data); |
|
|
|
MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); |
|
|
|
auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); |
|
|
|
KernelBuildInfoBuilder builder; |
|
|
|
builder.SetInputsFormat({input_format}); |
|
|
|
builder.SetInputReshapeType({reshape_type}); |
|
|
|
builder.SetInputReshapeType({reshape_type}); |
|
|
|
builder.SetOutputsFormat({output_format}); |
|
|
|
builder.SetInputsDeviceType({device_type}); |
|
|
|
builder.SetOutputsDeviceType({device_type}); |
|
|
|
builder.SetKernelType(ori_build_info->kernel_type()); |
|
|
|
builder.SetFusionType(ori_build_info->fusion_type()); |
|
|
|
builder.SetProcessor(ori_build_info->processor()); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get()); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, |
|
|
|
const bool need_padding, const std::string &op_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
std::vector<AnfNodePtr> trans_inputs; |
|
|
|
auto prim = std::make_shared<Primitive>(op_name); |
|
|
|
trans_inputs.push_back(NewValueNode(prim)); |
|
|
|
trans_inputs.push_back(input); |
|
|
|
CNodePtr trans_node = func_graph->NewCNode(trans_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(trans_node); |
|
|
|
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); |
|
|
|
if (need_padding) { |
|
|
|
// if need padding we should set the transdata node's shape to the padding shape |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, |
|
|
|
{trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, |
|
|
|
trans_node.get()); |
|
|
|
} else { |
|
|
|
need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) && |
|
|
|
op_name == kTransDataOpName); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, |
|
|
|
{AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); |
|
|
|
} |
|
|
|
if (!need_padding) { |
|
|
|
// don't need padding insert transdata only |
|
|
|
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); |
|
|
|
trans_node = trans_data; |
|
|
|
} else if (is_insert_input) { |
|
|
|
// if need padding & is input need insert a transdata |
|
|
|
// reshape[padding shape] -> transdata[padding shape] -> node |
|
|
|
auto padding_shape = |
|
|
|
trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); |
|
|
|
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); |
|
|
|
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name); |
|
|
|
trans_node = trans_data; |
|
|
|
} else { |
|
|
|
// if need padding & is output need insert a transdata |
|
|
|
// node -> transdata[padding shape] -> reshape[ori_shape] |
|
|
|
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); |
|
|
|
auto reshape_node = |
|
|
|
CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); |
|
|
|
trans_node = reshape_node; |
|
|
|
// special handle for ut |
|
|
|
if (trans_node->kernel_info() == nullptr) { |
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>(); |
|
|
|
trans_node->set_kernel_info(kernel_info); |
|
|
|
} |
|
|
|
// refresh the transdata's format to ori format & dst format |
|
|
|
MS_EXCEPTION_IF_NULL(trans_data); |
|
|
|
MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); |
|
|
|
auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); |
|
|
|
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_select); |
|
|
|
kernel_select->SelectKernel(trans_node); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); |
|
|
|
MS_EXCEPTION_IF_NULL(trans_node); |
|
|
|
trans_node->set_scope(input->scope()); |
|
|
|
return trans_node; |
|
|
|
} |
|
|
|
|
|
|
|
|