|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -28,12 +28,29 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
// {prim::kPrimSwitchLayer, {Index, layers}} |
|
|
|
// {prim::kPrimSwitch, cond, true_branch, false_branch} |
|
|
|
class SwitchDeferInline : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract()); |
|
|
|
if (true_abstract != nullptr) { |
|
|
|
*(true_abstract->func_graph()->switch_input()) = true; |
|
|
|
} |
|
|
|
auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract()); |
|
|
|
if (false_abstract != nullptr) { |
|
|
|
*(false_abstract->func_graph()->switch_input()) = true; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimSwitchLayer, Index, layers} |
|
|
|
class SwitchLayerDeferInline : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract()); |
|
|
|
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract()); |
|
|
|
for (auto elem : tuple->elements()) { |
|
|
|
auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem); |
|
|
|
if (abstract != nullptr) { |