|
|
|
@@ -65,8 +65,8 @@ void SetTransNodeAttr(const CNodePtr &trans_node) { |
|
|
|
|
|
|
|
std::string InitDefaultFormat(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) { |
|
|
|
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format"); |
|
|
|
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) { |
|
|
|
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat); |
|
|
|
if (attr == kOpFormat_NCDHW) { |
|
|
|
return kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
@@ -127,11 +127,11 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An |
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(node, 0); |
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0); |
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) { |
|
|
|
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " |
|
|
|
MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node " |
|
|
|
<< node->DebugString() << " trace: " << trace::DumpSourceLines(node); |
|
|
|
} |
|
|
|
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { |
|
|
|
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; |
|
|
|
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0"; |
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); |
|
|
|
} |
|
|
|
return node; |
|
|
|
@@ -364,7 +364,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod |
|
|
|
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); |
|
|
|
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); |
|
|
|
// In graph kernel, we check parameter, |
|
|
|
// the eliminate pass will not eliminate this case, so we just do not insert the noused cast. |
|
|
|
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast. |
|
|
|
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) { |
|
|
|
new_inputs.push_back(cur_input); |
|
|
|
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) { |
|
|
|
|