|
|
|
@@ -950,7 +950,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_ |
|
|
|
} |
|
|
|
(*output_desc)[kJFormat] = format; |
|
|
|
// special node |
|
|
|
if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) { |
|
|
|
if ((fusion_data_type == kFusionAddN || fusion_data_type == kFusionAdd) && shape.size() == 5) { |
|
|
|
std::vector<size_t> spec_shape = {}; |
|
|
|
spec_shape.emplace_back(shape[0]); |
|
|
|
spec_shape.emplace_back(shape[1]); |
|
|
|
@@ -995,7 +995,8 @@ void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNod |
|
|
|
bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name, |
|
|
|
const std::vector<mindspore::AnfNodePtr> &reorder_layer, |
|
|
|
std::map<const AnfNodePtr, FusionDataType> *spec_data_input) { |
|
|
|
if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) { |
|
|
|
if ((op_name == kReluGradV2OpName || op_name == kAddNOpName || op_name == kTensorAddOpName) && |
|
|
|
reorder_layer.empty()) { |
|
|
|
MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. "; |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -1005,6 +1006,8 @@ bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name, |
|
|
|
for (const auto &it : reorder_layer) { |
|
|
|
(*spec_data_input)[it] = kFusionAddN; |
|
|
|
} |
|
|
|
} else if (op_name == kTensorAddOpName) { |
|
|
|
(*spec_data_input)[reorder_layer[0]] = kFusionAdd; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -1020,7 +1023,7 @@ bool TbeKernelBuild::GetInputLayers(const std::vector<mindspore::AnfNodePtr> &in |
|
|
|
MS_EXCEPTION_IF_NULL(spec_data_input); |
|
|
|
auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) { |
|
|
|
auto op_name = AnfAlgo::GetCNodeName(it); |
|
|
|
return op_name == kConv2DBackpropInputOpName; |
|
|
|
return (op_name == kConv2DBackpropInputOpName || op_name == kConv2DOpName); |
|
|
|
}); |
|
|
|
bool need_spec = (result != compute_nodes.end()); |
|
|
|
size_t input_size = 0; |
|
|
|
|