@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@
#include <memory>
#include <vector>
#include <string>
#include <algorithm>
#include <unordered_set>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
@@ -27,78 +29,111 @@
namespace mindspore {
namespace opt {
namespace {
bool CanReorder(const FuncGraphManagerPtr &mng, const CNodePtr &transdata_node, const CNodePtr &cast_node) {
  auto transdata_input_type = AnfAlgo::GetInputDeviceDataType(transdata_node, 0);
  auto transdata_output_type = AnfAlgo::GetOutputDeviceDataType(transdata_node, 0);
  auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
  auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
  // Conditions for reordering transdata_cast to cast_transdata:
  // 1) the current transdata is only used by the cast
  // 2) transdata works on float32 (transdata supports float16/float32;
  //    transdata performs better on float16 because there is less data to process)
  // 3) cast works on float32 -> float16
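  // For example (the data formats here are only illustrative):
  //   TransData(fp32: NCHW -> NC1HWC0) -> Cast(fp32 -> fp16)
  // becomes
  //   Cast(fp32 -> fp16) -> TransData(fp16: NCHW -> NC1HWC0)
  // so the TransData has half as many bytes to move.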
  if (mng->node_users()[transdata_node].size() == 1 && transdata_input_type == kNumberTypeFloat32 &&
      transdata_output_type == transdata_input_type && cast_input_type == transdata_output_type &&
      cast_output_type == kNumberTypeFloat16) {
    return true;
  }
  return false;
}

bool IsTypeInsensitive(const CNodePtr &node) {
  // Nodes that change the input data type are not treated as type-insensitive nodes.
  static std::unordered_set<PrimitivePtr> type_insensitive_op_list{
    prim::KPrimTransData, prim::kPrimTranspose, prim::kPrimExpandDims, prim::kPrimReshape,
    prim::kPrimSqueeze,   prim::kPrimTile,      prim::kPrimNeg,        prim::kPrimRelu,
    prim::kPrimMaximum,   prim::kPrimMinimum,   prim::kPrimSelect};

  return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
CastType GetCastType(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
    MS_LOG(EXCEPTION) << "Only process for Cast!";
  }
  TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);

  if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
    return CAST_UP;
  }

  if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
    return CAST_DOWN;
  }

  return CAST_OTHER;
}
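// Note: only float16 <-> float32 casts are classified here; any other cast (e.g. int32 -> float32)
// falls into CAST_OTHER and is not reordered by this pass.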

std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
  std::vector<size_t> op_input_indexes;
  if (node == nullptr || !IsTypeInsensitive(node)) {
    return op_input_indexes;
  }

  // Data input index starts from 0.
  if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
    op_input_indexes = {0, 1};
  } else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
    op_input_indexes = {1, 2};
  } else {
    op_input_indexes = {0};
  }
  return op_input_indexes;
}
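// For instance, Select(cond, x, y) yields {1, 2}: input 0 is the boolean condition and carries no
// float data, so only inputs 1 and 2 are considered when casts are moved around.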

bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes, const TypeId &base_type) {
  MS_EXCEPTION_IF_NULL(node);

  // node's inputs at check_indexes should be of type base_type
  for (const auto &index : check_indexes) {
    if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
      return false;
    }
  }
  return true;
}

void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, const CNodePtr &node) {
  // Initial
  // TransData: (type0, format0) -> (type0, format1)
  // Cast: (type0, format1) -> (type1, format1)
  // After reorder
  // Cast: (type0, format0) -> (type1, format0)
  // TransData: (type1, format0) -> (type1, format1)
  auto type0 = AnfAlgo::GetInputDeviceDataType(transdata_node, 0);
  auto type1 = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
  auto format0 = AnfAlgo::GetInputFormat(transdata_node, 0);
  auto format1 = AnfAlgo::GetOutputFormat(transdata_node, 0);

  auto abstract = transdata_node->abstract();
  auto scope = cast_node->scope();
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) {
  MS_EXCEPTION_IF_NULL(orig_node);
  MS_EXCEPTION_IF_NULL(new_node);

  auto node_name = AnfAlgo::GetCNodeName(new_node);
  auto orig_node_name = AnfAlgo::GetCNodeName(orig_node);
  if (orig_node_name != node_name) {
    MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
  }

  AbstractBasePtr new_abstract{nullptr};
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  auto kernel_type = AnfAlgo::GetKernelType(cast_node);
  auto op_pattern = AnfAlgo::GetOpPattern(cast_node);
  auto fusion_type = AnfAlgo::GetFusionType(cast_node);
  auto processor = AnfAlgo::GetProcessor(cast_node);
  std::vector<TypeId> outputs_device_type{node_type};
  KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)};
  kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)};
  kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)};
  kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)};

  auto node_name = AnfAlgo::GetCNodeName(node);
  auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node);
  for (size_t i = 0; i < node_data_inputs_num; ++i) {
    auto node_input = AnfAlgo::GetInputNode(new_node, i);
    auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0);
    auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0);
    inputs_format.push_back(node_input_format);
    inputs_device_type.push_back(node_input_type);
  }
  if (node_name == "Cast") {
    inputs_format.push_back(format0);
    outputs_format.push_back(format0);
    inputs_device_type.push_back(type0);
    outputs_device_type.push_back(type1);
    // Set attrs
    AnfAlgo::CopyNodeAttrs(cast_node, node);
  } else if (node_name == "TransData") {
    abstract = cast_node->abstract();
    scope = transdata_node->scope();
    inputs_format.push_back(format0);
    outputs_format.push_back(format1);
    inputs_device_type.push_back(type1);
    outputs_device_type.push_back(type1);
    kernel_type = AnfAlgo::GetKernelType(transdata_node);
    op_pattern = AnfAlgo::GetOpPattern(transdata_node);
    fusion_type = AnfAlgo::GetFusionType(transdata_node);
    processor = AnfAlgo::GetProcessor(transdata_node);
    // Set attrs
    AnfAlgo::CopyNodeAttrs(transdata_node, node);
    auto node_input = AnfAlgo::GetInputNode(new_node, 0);
    new_abstract =
      std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape());
    outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0));
  } else {
    MS_LOG(EXCEPTION) << "Node must be Cast or TransData";
    new_abstract =
      std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape());
    outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0));
  }

  // Set abstract info
  node->set_abstract(abstract);
  // Set scope info
  node->set_scope(scope);
  new_node->set_abstract(new_abstract);
  // Set attrs
  AnfAlgo::CopyNodeAttrs(orig_node, new_node);
  // Set kernel build info
  node->set_kernel_info(std::make_shared<device::KernelInfo>());
  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(inputs_format);
  info_builder.SetInputsDeviceType(inputs_device_type);
@@ -108,10 +143,141 @@ void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, cons
  info_builder.SetOpPattern(op_pattern);
  info_builder.SetFusionType(fusion_type);
  info_builder.SetProcessor(processor);
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), node.get());
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
}
}  // namespace

void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
                                              const std::vector<AnfNodePtr> &new_input_at_indexes,
                                              std::vector<AnfNodePtr> *new_inputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_inputs);
  if (indexes.size() != new_input_at_indexes.size()) {
    MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
                      << new_input_at_indexes.size();
  }
  if (!new_inputs->empty()) {
    new_inputs->resize(0);
  }

  // The node's inputs at indexes are replaced by new_input_at_indexes; the other inputs are kept as-is.
  std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
  auto node_inputs_num = node->size();
  size_t idx = 0;
  for (size_t i = 0; i < node_inputs_num; ++i) {
    if (indexes_set.find(i) == indexes_set.end()) {
      new_inputs->push_back(node->input(i));
    } else {
      new_inputs->push_back(new_input_at_indexes[idx++]);
    }
  }
}
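// Example: for node inputs {Prim, x, y, z}, indexes {1, 3} and new_input_at_indexes {a, b}, *new_inputs
// becomes {Prim, a, y, b}. Callers shift data-input indexes by +1 before calling this helper because
// input 0 of a CNode is the primitive value node rather than a data input.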

bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                                const CNodePtr &node) {
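  // Illustration of reorder pattern 2 (TypeInsensitive-CastDown --> CastDown-TypeInsensitive),
  // taking Relu as the type-insensitive op:
  //   before: x(fp32) -> Relu(fp32) -> Cast(fp32 -> fp16)
  //   after : x(fp32) -> Cast(fp32 -> fp16) -> Relu(fp16)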
  // Limitation: Current cast node is CAST_DOWN.
  if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) {
    return false;
  }

  auto node_input = AnfAlgo::GetInputNode(node, 0);
  auto type_insens_node = node_input->cast<CNodePtr>();
  // Limitation:
  //   The node before the cast node is a type-insensitive node.
  //   That type-insensitive node is only used by the current cast node.
  if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
      mng->node_users()[type_insens_node].size() > 1) {
    return false;
  }

  auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
  // Limitation: The type-insensitive node's data inputs have the same data type.
  if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) {
    return false;
  }

  std::vector<AnfNodePtr> new_cast_nodes;
  for (const auto &index : op_input_indexes) {
    auto new_cast_node =
      func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
    SetNodeInfo(node, new_cast_node, cast_out_type);
    new_cast_nodes.push_back(new_cast_node);
  }

  std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
                 [](const size_t &idx) { return idx + 1; });

  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type);

  (void)mng->Replace(node, new_type_insens_node);
  return true;
}

bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                              const CNodePtr &node) {
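  // Illustration of reorder pattern 1 (CastUp-TypeInsensitive --> TypeInsensitive-CastUp),
  // taking Relu as the type-insensitive op:
  //   before: x(fp16) -> Cast(fp16 -> fp32) -> Relu(fp32)
  //   after : x(fp16) -> Relu(fp16) -> Cast(fp16 -> fp32)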
  if (!IsTypeInsensitive(node)) {
    return false;
  }

  // Limitation:
  //   Certain inputs of the type-insensitive node are Cast nodes.
  //   These Cast nodes are CAST_UP.
  //   Each of these Cast nodes is only used by the current type-insensitive node.
  std::vector<CNodePtr> cast_nodes;
  std::vector<AnfNodePtr> cast_input_nodes;
  auto op_input_indexes = GetOpDataInputIndexes(node);
  for (const auto &index : op_input_indexes) {
    auto node_input = AnfAlgo::GetInputNode(node, index);
    auto cast_node = node_input->cast<CNodePtr>();
    if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
        mng->node_users()[cast_node].size() == 1) {
      cast_nodes.push_back(cast_node);
      cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
    }
  }
  if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
    return false;
  }

  auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
  auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
  // Limitation: All these Cast nodes cast the same type to another type.
  if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) {
        return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type;
      })) {
    return false;
  }
  // Limitation: The type-insensitive node's data inputs have the same data type.
  if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) {
    return false;
  }

  std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
                 [](const size_t &idx) { return idx + 1; });

  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  SetNodeInfo(node, new_type_insens_node, cast_input_type);

  auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
  SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type);

  (void)mng->Replace(node, new_cast_node);
  return true;
}

bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) {
  // Reorder a cast node and a type-insensitive node inside a graph kernel sub-graph. This function has several
  // limitations; see the comments that start with "Limitation:" in this file.
  // Limitation: The type-insensitive node is assumed not to change the type of its input nodes; otherwise it can be
  // seen as another cast node in some sense, such as the LessEqual operator, which works on two inputs and outputs a
  // boolean result.
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
@@ -121,40 +287,52 @@ bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    // Find cast node.
    auto cast_node = anf_node->cast<CNodePtr>();
    if (cast_node == nullptr || !AnfAlgo::CheckPrimitiveType(cast_node, prim::kPrimCast)) {
    auto node = anf_node->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }

    // Find transdata node before cast node.
    auto cast_input = AnfAlgo::GetInputNode(cast_node, 0);
    auto transdata_node = cast_input->cast<CNodePtr>();
    if (transdata_node == nullptr || !AnfAlgo::CheckPrimitiveType(transdata_node, prim::KPrimTransData)) {
      continue;
    }

    // Reorder transdata_cast to cast_transdata if possible.
    if (!CanReorder(mng, transdata_node, cast_node)) {
      continue;
    if (IsTypeInsensitive(node)) {
      // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
      changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
    } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
      // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
      changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
    }
  }

    MS_LOG(INFO) << "Reorder " << transdata_node->fullname_with_scope() << ", " << cast_node->fullname_with_scope();
  return changed;
}

    auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), transdata_node->inputs()[1]});
    SetNodeInfo(transdata_node, cast_node, new_cast_node);
bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }

    auto new_transdata_node = func_graph->NewCNode({NewValueNode(prim::KPrimTransData), new_cast_node});
    SetNodeInfo(transdata_node, cast_node, new_transdata_node);
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    auto node = anf_node->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }

    (void)mng->Replace(cast_node, new_transdata_node);
    changed = true;
    if (AnfAlgo::IsGraphKernel(node)) {
      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
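      // Run the reorder pass on the sub-graph repeatedly until it reaches a fixpoint, since one
      // successful reorder can expose another reorder opportunity.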
      bool need_traverse = true;
      while (need_traverse) {
        need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
        if (need_traverse) {
          changed = true;
        }
      }
    }
  }

  return changed;
}
}  // namespace

bool ReorderOps::Run(const FuncGraphPtr &func_graph) { return ReorderTransDataCast(func_graph); }
}  // namespace opt
}  // namespace mindspore