|
|
|
@@ -18,7 +18,6 @@ |
|
|
|
#include <memory> |
|
|
|
#include <vector> |
|
|
|
#include <string> |
|
|
|
#include <algorithm> |
|
|
|
#include <unordered_set> |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
@@ -89,7 +88,7 @@ bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &c |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) { |
|
|
|
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) { |
|
|
|
MS_EXCEPTION_IF_NULL(orig_node); |
|
|
|
MS_EXCEPTION_IF_NULL(new_node); |
|
|
|
|
|
|
|
@@ -100,32 +99,19 @@ void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const Type |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr new_abstract{nullptr}; |
|
|
|
std::vector<std::string> inputs_format; |
|
|
|
std::vector<std::string> outputs_format; |
|
|
|
std::vector<TypeId> inputs_device_type; |
|
|
|
std::vector<TypeId> outputs_device_type{node_type}; |
|
|
|
KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)}; |
|
|
|
kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)}; |
|
|
|
kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)}; |
|
|
|
kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)}; |
|
|
|
|
|
|
|
auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node); |
|
|
|
for (size_t i = 0; i < node_data_inputs_num; ++i) { |
|
|
|
auto node_input = AnfAlgo::GetInputNode(new_node, i); |
|
|
|
auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0); |
|
|
|
auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0); |
|
|
|
inputs_format.push_back(node_input_format); |
|
|
|
inputs_device_type.push_back(node_input_type); |
|
|
|
if (node_io_info.outputs_type.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope(); |
|
|
|
} |
|
|
|
if (node_name == "Cast") { |
|
|
|
auto node_input = AnfAlgo::GetInputNode(new_node, 0); |
|
|
|
new_abstract = |
|
|
|
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape()); |
|
|
|
outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0)); |
|
|
|
MS_EXCEPTION_IF_NULL(node_input); |
|
|
|
MS_EXCEPTION_IF_NULL(node_input->abstract()); |
|
|
|
new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]), |
|
|
|
node_input->abstract()->BuildShape()); |
|
|
|
} else { |
|
|
|
new_abstract = |
|
|
|
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape()); |
|
|
|
outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0)); |
|
|
|
MS_EXCEPTION_IF_NULL(orig_node->abstract()); |
|
|
|
new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]), |
|
|
|
orig_node->abstract()->BuildShape()); |
|
|
|
} |
|
|
|
|
|
|
|
// Set abstract info |
|
|
|
@@ -135,14 +121,14 @@ void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const Type |
|
|
|
// Set kernel build info |
|
|
|
new_node->set_kernel_info(std::make_shared<device::KernelInfo>()); |
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder; |
|
|
|
info_builder.SetInputsFormat(inputs_format); |
|
|
|
info_builder.SetInputsDeviceType(inputs_device_type); |
|
|
|
info_builder.SetOutputsFormat(outputs_format); |
|
|
|
info_builder.SetOutputsDeviceType(outputs_device_type); |
|
|
|
info_builder.SetKernelType(kernel_type); |
|
|
|
info_builder.SetOpPattern(op_pattern); |
|
|
|
info_builder.SetFusionType(fusion_type); |
|
|
|
info_builder.SetProcessor(processor); |
|
|
|
info_builder.SetInputsFormat(node_io_info.inputs_format); |
|
|
|
info_builder.SetInputsDeviceType(node_io_info.inputs_type); |
|
|
|
info_builder.SetOutputsFormat(node_io_info.outputs_format); |
|
|
|
info_builder.SetOutputsDeviceType(node_io_info.outputs_type); |
|
|
|
info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node)); |
|
|
|
info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node)); |
|
|
|
info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node)); |
|
|
|
info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node)); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get()); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
@@ -156,16 +142,22 @@ void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::v |
|
|
|
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size " |
|
|
|
<< new_input_at_indexes.size(); |
|
|
|
} |
|
|
|
if (!new_inputs->empty()) { |
|
|
|
new_inputs->resize(0); |
|
|
|
|
|
|
|
auto node_inputs_num = node->size(); |
|
|
|
if (node_inputs_num == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope(); |
|
|
|
} |
|
|
|
|
|
|
|
// node's inputs at indexes change to new_input_at_indexes |
|
|
|
if (!new_inputs->empty()) { |
|
|
|
new_inputs->resize(0); |
|
|
|
} |
|
|
|
new_inputs->push_back(node->input(0)); |
|
|
|
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end()); |
|
|
|
auto node_inputs_num = node->size(); |
|
|
|
size_t idx = 0; |
|
|
|
for (size_t i = 0; i < node_inputs_num; ++i) { |
|
|
|
if (indexes_set.find(i) == indexes_set.end()) { |
|
|
|
for (size_t i = 1; i < node_inputs_num; ++i) { |
|
|
|
size_t data_idx = i - 1; |
|
|
|
if (indexes_set.find(data_idx) == indexes_set.end()) { |
|
|
|
new_inputs->push_back(node->input(i)); |
|
|
|
} else { |
|
|
|
new_inputs->push_back(new_input_at_indexes[idx++]); |
|
|
|
@@ -173,13 +165,57 @@ void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::v |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes, |
|
|
|
const std::vector<AnfNodePtr> &input_at_indexes, |
|
|
|
NodeIOInfo *new_inputs_info, bool from_input) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(new_inputs_info); |
|
|
|
if (indexes.size() != input_at_indexes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size " |
|
|
|
<< input_at_indexes.size(); |
|
|
|
} |
|
|
|
|
|
|
|
auto node_inputs_num = node->size(); |
|
|
|
if (node_inputs_num == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope(); |
|
|
|
} |
|
|
|
|
|
|
|
// node's inputs info at indexes change to input_at_indexes's input or output info |
|
|
|
new_inputs_info->inputs_format.resize(0); |
|
|
|
new_inputs_info->inputs_type.resize(0); |
|
|
|
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end()); |
|
|
|
size_t idx = 0; |
|
|
|
for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) { |
|
|
|
if (indexes_set.find(data_idx) == indexes_set.end()) { |
|
|
|
new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx)); |
|
|
|
new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx)); |
|
|
|
} else { |
|
|
|
if (from_input) { |
|
|
|
new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0)); |
|
|
|
new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0)); |
|
|
|
} else { |
|
|
|
new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0)); |
|
|
|
new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0)); |
|
|
|
} |
|
|
|
idx++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng, |
|
|
|
const CNodePtr &node) { |
|
|
|
// Limitation: Current cast node is CAST_DOWN. |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) { |
|
|
|
// Limitation: |
|
|
|
// Current cast node is CAST_DOWN. |
|
|
|
// Cast node will not change the input format. |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN || |
|
|
|
AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0); |
|
|
|
auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0); |
|
|
|
auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0); |
|
|
|
|
|
|
|
auto node_input = AnfAlgo::GetInputNode(node, 0); |
|
|
|
auto type_insens_node = node_input->cast<CNodePtr>(); |
|
|
|
// Limitation: |
|
|
|
@@ -190,11 +226,9 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0); |
|
|
|
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0); |
|
|
|
auto op_input_indexes = GetOpDataInputIndexes(type_insens_node); |
|
|
|
// Limitation: Type insensitive node's inputs have same data type. |
|
|
|
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) { |
|
|
|
// Limitation: Type insensitive node's inputs are the large type. |
|
|
|
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -202,17 +236,23 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, |
|
|
|
for (const auto &index : op_input_indexes) { |
|
|
|
auto new_cast_node = |
|
|
|
func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)}); |
|
|
|
SetNodeInfo(node, new_cast_node, cast_out_type); |
|
|
|
NodeIOInfo cast_io_info; |
|
|
|
cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index)); |
|
|
|
cast_io_info.outputs_format = cast_io_info.inputs_format; |
|
|
|
cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index)); |
|
|
|
cast_io_info.outputs_type.push_back(small_type); |
|
|
|
SetNodeInfo(node, new_cast_node, cast_io_info); |
|
|
|
new_cast_nodes.push_back(new_cast_node); |
|
|
|
} |
|
|
|
|
|
|
|
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(), |
|
|
|
[](const size_t &idx) { return idx + 1; }); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> type_insens_node_new_inputs; |
|
|
|
SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs); |
|
|
|
NodeIOInfo type_insens_io_info; |
|
|
|
type_insens_io_info.outputs_format.push_back(pattern_output_format); |
|
|
|
type_insens_io_info.outputs_type.push_back(small_type); |
|
|
|
SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false); |
|
|
|
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs); |
|
|
|
SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type); |
|
|
|
SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info); |
|
|
|
|
|
|
|
(void)mng->Replace(node, new_type_insens_node); |
|
|
|
return true; |
|
|
|
@@ -227,14 +267,16 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co |
|
|
|
// Limitation: |
|
|
|
// Certain inputs of type insensitive node are cast node. |
|
|
|
// Cast nodes are CAST_UP. |
|
|
|
// Cast nodes will not change the input format. |
|
|
|
// All these cast nodes are only used by current type insensitive node. |
|
|
|
std::vector<CNodePtr> cast_nodes; |
|
|
|
std::vector<AnfNodePtr> cast_nodes; |
|
|
|
std::vector<AnfNodePtr> cast_input_nodes; |
|
|
|
auto op_input_indexes = GetOpDataInputIndexes(node); |
|
|
|
for (const auto &index : op_input_indexes) { |
|
|
|
auto node_input = AnfAlgo::GetInputNode(node, index); |
|
|
|
auto cast_node = node_input->cast<CNodePtr>(); |
|
|
|
if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP && |
|
|
|
AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) && |
|
|
|
mng->node_users()[cast_node].size() == 1) { |
|
|
|
cast_nodes.push_back(cast_node); |
|
|
|
cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0)); |
|
|
|
@@ -244,29 +286,37 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0); |
|
|
|
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0); |
|
|
|
auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0); |
|
|
|
auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0); |
|
|
|
auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0); |
|
|
|
|
|
|
|
// Limitation: All these cast nodes cast same type to another type. |
|
|
|
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) { |
|
|
|
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type; |
|
|
|
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) { |
|
|
|
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type; |
|
|
|
})) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// Limitation: Type insensitive node's inputs have same data type. |
|
|
|
if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) { |
|
|
|
if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(), |
|
|
|
[](const size_t &idx) { return idx + 1; }); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> type_insens_node_new_inputs; |
|
|
|
SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs); |
|
|
|
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs); |
|
|
|
SetNodeInfo(node, new_type_insens_node, cast_input_type); |
|
|
|
NodeIOInfo type_insens_io_info; |
|
|
|
type_insens_io_info.outputs_format.push_back(pattern_output_format); |
|
|
|
type_insens_io_info.outputs_type.push_back(small_type); |
|
|
|
SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true); |
|
|
|
SetNodeInfo(node, new_type_insens_node, type_insens_io_info); |
|
|
|
|
|
|
|
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node}); |
|
|
|
SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type); |
|
|
|
NodeIOInfo cast_io_info; |
|
|
|
cast_io_info.inputs_format.push_back(pattern_output_format); |
|
|
|
cast_io_info.outputs_format = cast_io_info.inputs_format; |
|
|
|
cast_io_info.inputs_type.push_back(small_type); |
|
|
|
cast_io_info.outputs_type.push_back(large_type); |
|
|
|
SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info); |
|
|
|
|
|
|
|
(void)mng->Replace(node, new_cast_node); |
|
|
|
return true; |
|
|
|
|