|
|
|
@@ -23,7 +23,6 @@ |
|
|
|
#include "schema/inner/model_generated.h" |
|
|
|
#include "tools/optimizer/common/gllo_utils.h" |
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore::opt { |
|
|
|
namespace { |
|
|
|
constexpr size_t kActivationInputsLength = 2; |
|
|
|
@@ -68,14 +67,18 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c |
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value)); |
|
|
|
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value); |
|
|
|
MS_ASSERT(primc != nullptr); |
|
|
|
primc->SetActivationType(activation_type); |
|
|
|
return pre_node; |
|
|
|
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) { |
|
|
|
primc->SetActivationType(activation_type); |
|
|
|
return pre_node; |
|
|
|
} |
|
|
|
} else if (node_type == schema::PrimitiveType_DepthwiseConv2D) { |
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value)); |
|
|
|
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value); |
|
|
|
MS_ASSERT(primc != nullptr); |
|
|
|
primc->SetActivationType(activation_type); |
|
|
|
return pre_node; |
|
|
|
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) { |
|
|
|
primc->SetActivationType(activation_type); |
|
|
|
return pre_node; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "conv activation pass match only conv2d or depthwise_conv2d "; |
|
|
|
} |
|
|
|
|