|
|
|
@@ -33,7 +33,6 @@ namespace opt { |
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; |
|
|
|
namespace { |
|
|
|
// Formats describing 5-D (3-D spatial) tensors.
// NOTE(review): not referenced in this chunk; presumably used elsewhere to
// special-case 3-D trans-data insertion — confirm against the full file.
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};

// "Common" layouts that need no format transformation: outputs whose format is
// already in this set are skipped when inserting trans ops (see the
// kCommonFormatSet.find(...) check in InsertTransOpForMultipleOutput).
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
|
|
|
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, |
|
|
|
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { |
|
|
|
std::vector<AnfNodePtr> trans_inputs; |
|
|
|
@@ -82,45 +81,18 @@ std::string InitDefaultFormat(const AnfNodePtr &node) { |
|
|
|
return default_format; |
|
|
|
} |
|
|
|
|
|
|
|
// Inserts a TransData node (together with a Reshape when the inferred shape must
// be padded to the destination layout's rank) either in front of input
// `insert_index` of `node` (is_insert_input == true) or behind output
// `insert_index` of `node` (is_insert_input == false).
// Returns the node that should replace the original edge in the graph.
//
// NOTE(review): this block previously had a stray copy of ReFreshInferShape()
// spliced into the middle of the final `else` branch, leaving that branch
// unclosed — a nested function definition, which is invalid C++.  The stray
// definition has been removed and the branch closed; the remaining logic is
// unchanged.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  AnfNodePtr trans_node = nullptr;
  CNodePtr trans_data = nullptr;
  MS_EXCEPTION_IF_NULL(node);
  // Init: pick the node to transform and the source/destination formats.
  // When inserting on an input, the producer of that input edge is transformed;
  // otherwise `node`'s own output is transformed.
  std::string default_format = InitDefaultFormat(node);
  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
  std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
  std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
                                                   : AnfAlgo::GetOutputReshapeType(node, insert_index);
  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
  // The non-default-format side decides whether the shape needs padding.
  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
  if (!need_padding) {
    // don't need padding, insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // need padding & inserting on an input:
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
    // Keep the original abstract on the transdata so downstream inference is unchanged.
    trans_data->set_abstract(input_node->abstract());
  } else {
    // need padding & inserting on an output:
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
  return trans_node;
}
|
|
|
|
|
|
|
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, |
|
|
|
@@ -161,15 +133,6 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An |
|
|
|
return node; |
|
|
|
} |
|
|
|
|
|
|
|
void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) { |
|
|
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); |
|
|
|
auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const KernelSelectPtr &kernel_select) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
@@ -177,10 +140,6 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
|
size_t out_num = AnfAlgo::GetOutputTensorNum(node); |
|
|
|
std::string op_name; |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
op_name = AnfAlgo::GetCNodeName(node); |
|
|
|
} |
|
|
|
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { |
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); |
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) { |
|
|
|
@@ -191,7 +150,6 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const |
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); |
|
|
|
if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) { |
|
|
|
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false); |
|
|
|
ReFreshInferShape(trans_op, op_name); |
|
|
|
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) { |
|
|
|
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0); |
|
|
|
} |
|
|
|
@@ -205,6 +163,50 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
// Inserts a TransData node (together with a Reshape when the inferred shape must
// be padded to the destination layout's rank) either in front of input
// `insert_index` of `node` (is_insert_input == true) or behind output
// `insert_index` of `node` (is_insert_input == false).
// Returns the node that should replace the original edge in the graph.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  AnfNodePtr trans_node = nullptr;
  CNodePtr trans_data = nullptr;
  MS_EXCEPTION_IF_NULL(node);
  // Init: pick the node to transform and the source/destination formats.
  // When inserting on an input, the producer of that input edge is transformed;
  // otherwise `node`'s own output is transformed.
  std::string default_format = InitDefaultFormat(node);
  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
  std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
  std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
                                                   : AnfAlgo::GetOutputReshapeType(node, insert_index);
  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
  // The non-default-format side decides whether the shape needs padding.
  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
  if (!need_padding) {
    // don't need padding insert transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
  } else if (is_insert_input) {
    // if need padding & is input need insert a transdata
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
    trans_node = trans_data;
    // Keep the original abstract on the transdata so downstream inference is unchanged.
    trans_data->set_abstract(input_node->abstract());
  } else {
    // if need padding & is output need insert a transdata
    // node -> transdata[padding shape] -> reshape[ori_shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
    trans_node = reshape_node;
  }
  // refresh the transdata's format to ori format & dst format
  RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
  if (!is_insert_input) {
    // Output-side insertion: fix up the inferred shape for the
    // BasicLSTMCellWeightGrad special case (see ReFreshInferShape).
    ReFreshInferShape(trans_node, node);
  }
  return trans_node;
}
|
|
|
|
|
|
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, |
|
|
|
const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type, |
|
|
|
const TypeId &type_id) { |
|
|
|
|