|
|
@@ -14,7 +14,7 @@ |
|
|
* limitations under the License. |
|
|
* limitations under the License. |
|
|
*/ |
|
|
*/ |
|
|
|
|
|
|
|
|
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" |
|
|
|
|
|
|
|
|
#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h" |
|
|
#include <utility> |
|
|
#include <utility> |
|
|
#include <vector> |
|
|
#include <vector> |
|
|
#include <memory> |
|
|
#include <memory> |
|
|
@@ -26,7 +26,7 @@ |
|
|
|
|
|
|
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace opt { |
|
|
namespace opt { |
|
|
session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const { |
|
|
|
|
|
|
|
|
session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const { |
|
|
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); |
|
|
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); |
|
|
AnfNodePtr cur_node = kernel_with_index.first; |
|
|
AnfNodePtr cur_node = kernel_with_index.first; |
|
|
size_t cur_out_index = kernel_with_index.second; |
|
|
size_t cur_out_index = kernel_with_index.second; |
|
|
@@ -61,8 +61,9 @@ session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr |
|
|
return kernel_with_index; |
|
|
return kernel_with_index; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
const size_t output_index, const size_t input_index) const { |
|
|
|
|
|
|
|
|
void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, |
|
|
|
|
|
const CNodePtr &cnode, const size_t output_index, |
|
|
|
|
|
const size_t input_index) const { |
|
|
// record the ref_pair |
|
|
// record the ref_pair |
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
@@ -71,10 +72,10 @@ void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_g |
|
|
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); |
|
|
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
const AnfNodePtr &get_item, const AnfNodePtr &final_node, |
|
|
|
|
|
size_t final_index, |
|
|
|
|
|
const session::KernelWithIndex &origin_pair) const { |
|
|
|
|
|
|
|
|
void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
const AnfNodePtr &get_item, |
|
|
|
|
|
const AnfNodePtr &final_node, size_t final_index, |
|
|
|
|
|
const session::KernelWithIndex &origin_pair) const { |
|
|
// record the ref_pair |
|
|
// record the ref_pair |
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
@@ -95,9 +96,10 @@ void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph |
|
|
|
|
|
|
|
|
// if get_item is nullptr, the additional node will link to the cnode |
|
|
// if get_item is nullptr, the additional node will link to the cnode |
|
|
// else the additional node will link to the get_item node (the get_item node links to the cnode) |


// else the additional node will link to the get_item node (the get_item node links to the cnode) |
|
|
CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
size_t output_index, size_t input_index, |
|
|
|
|
|
const CNodePtr &get_item) const { |
|
|
|
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, |
|
|
|
|
|
const CNodePtr &cnode, size_t output_index, |
|
|
|
|
|
size_t input_index, |
|
|
|
|
|
const CNodePtr &get_item) const { |
|
|
CNodePtr final_node = (get_item == nullptr ? cnode : get_item); |
|
|
CNodePtr final_node = (get_item == nullptr ? cnode : get_item); |
|
|
bool need_refresh_ref_addr = false; |
|
|
bool need_refresh_ref_addr = false; |
|
|
size_t final_index = output_index; |
|
|
size_t final_index = output_index; |
|
|
@@ -149,8 +151,9 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_ |
|
|
return final_node; |
|
|
return final_node; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, |
|
|
|
|
|
const CNodePtr &cnode, const FuncGraphPtr &func_graph) const { |
|
|
|
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, |
|
|
|
|
|
const CNodePtr &cnode, |
|
|
|
|
|
const FuncGraphPtr &func_graph) const { |
|
|
std::vector<AnfNodePtr> depend_nodes; |
|
|
std::vector<AnfNodePtr> depend_nodes; |
|
|
if (get_item != nullptr) { |
|
|
if (get_item != nullptr) { |
|
|
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node}; |
|
|
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node}; |
|
|
@@ -159,8 +162,8 @@ CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNo |
|
|
} |
|
|
} |
|
|
return func_graph->NewCNode(depend_nodes); |
|
|
return func_graph->NewCNode(depend_nodes); |
|
|
} |
|
|
} |
|
|
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) const { |
|
|
|
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput( |
|
|
|
|
|
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const { |
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
auto ref_infos = op_info->ref_infos(); |
|
|
auto ref_infos = op_info->ref_infos(); |
|
|
std::vector<AnfNodePtr> make_tuple_inputs; |
|
|
std::vector<AnfNodePtr> make_tuple_inputs; |
|
|
@@ -185,8 +188,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_ |
|
|
return make_tuple; |
|
|
return make_tuple; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) const { |
|
|
|
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) const { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
auto ref_infos = op_info->ref_infos(); |
|
|
auto ref_infos = op_info->ref_infos(); |
|
|
@@ -200,13 +203,14 @@ CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const BaseRef DealRefTransAndCast::DefinePattern() const { |
|
|
|
|
|
|
|
|
const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const { |
|
|
VarPtr V = std::make_shared<CondVar>(UnVisited); |
|
|
VarPtr V = std::make_shared<CondVar>(UnVisited); |
|
|
VarPtr Xs = std::make_shared<SeqVar>(); |
|
|
VarPtr Xs = std::make_shared<SeqVar>(); |
|
|
return VectorRef({V, Xs}); |
|
|
return VectorRef({V, Xs}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { |
|
|
|
|
|
|
|
|
void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph, |
|
|
|
|
|
const CNodePtr &cnode) const { |
|
|
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { |
|
|
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { |
|
|
auto input_size = AnfAlgo::GetInputTensorNum(cnode); |
|
|
auto input_size = AnfAlgo::GetInputTensorNum(cnode); |
|
|
for (size_t i = 0; i < input_size; ++i) { |
|
|
for (size_t i = 0; i < input_size; ++i) { |
|
|
@@ -219,8 +223,8 @@ void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, con |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
|
|
|
const EquivPtr &) const { |
|
|
|
|
|
|
|
|
const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
|
|
|
const EquivPtr &) const { |
|
|
if (node == nullptr || !node->isa<CNode>()) { |
|
|
if (node == nullptr || !node->isa<CNode>()) { |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
@@ -250,11 +254,12 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, |
|
|
|
|
|
const CNodePtr &cnode) const { |
|
|
|
|
|
|
|
|
CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, |
|
|
|
|
|
const CNodePtr &cnode) const { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); |
|
|
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); |
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
|
|
// When only one of the input and output formats is a special format, the node just needs to be split into transpose and transdata |
|
|
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || |
|
|
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || |
|
|
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { |
|
|
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { |
|
|
if (IsFormatInvaild(cnode)) { |
|
|
if (IsFormatInvaild(cnode)) { |
|
|
@@ -262,6 +267,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f |
|
|
} |
|
|
} |
|
|
return cnode; |
|
|
return cnode; |
|
|
} |
|
|
} |
|
|
|
|
|
// When both the input and output formats are special formats, |
|
|
|
|
|
// the node should be split into two transdata nodes connected by the default format |
|
|
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); |
|
|
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); |
|
|
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); |
|
|
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); |
|
|
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); |
|
|
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); |
|
|
@@ -273,6 +280,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f |
|
|
next_trans_node->set_abstract(cnode->abstract()); |
|
|
next_trans_node->set_abstract(cnode->abstract()); |
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); |
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); |
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); |
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); |
|
|
|
|
|
RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode); |
|
|
|
|
|
RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(next_trans_node, 0), next_trans_node); |
|
|
if (IsFormatInvaild(cnode)) { |
|
|
if (IsFormatInvaild(cnode)) { |
|
|
auto after_split_node = DoSplit(func_graph, cnode); |
|
|
auto after_split_node = DoSplit(func_graph, cnode); |
|
|
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); |
|
|
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); |