|
|
|
@@ -88,7 +88,10 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p |
|
|
|
auto total_node_count = input_node_indexes.size() + output_node_indexes.size(); |
|
|
|
size_t half_count = total_node_count / 2; |
|
|
|
if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) { |
|
|
|
if (node->primitive->value.AsActivation()->type == schema::ActivationType_LEAKY_RELU) { |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
MS_ASSERT(node->primitive != nullptr); |
|
|
|
if (node->primitive->value.AsActivation() != nullptr && |
|
|
|
node->primitive->value.AsActivation()->type == schema::ActivationType_LEAKY_RELU) { |
|
|
|
return has_trans_count >= half_count; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -223,6 +226,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { |
|
|
|
changed = false; |
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
auto &node = *iter; |
|
|
|
if (node == nullptr && node->primitive == nullptr) { |
|
|
|
MS_LOG(ERROR) << "node or primitive null"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
auto type = node->primitive->value.type; |
|
|
|
if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) { |
|
|
|
continue; |
|
|
|
|