
!4188 fix switch layer single prim cell

Merge pull request !4188 from riemann_penn/fix_switch_layer_sigle_prim_cell
Tag: v0.7.0-beta
Committed by mindspore-ci-bot on Gitee, 5 years ago · commit 0e4d2c2535
11 changed files with 90 additions and 1 deletion
  1. +3 -0    mindspore/ccsrc/debug/anf_ir_utils.cc
  2. +2 -0    mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
  3. +5 -0    mindspore/ccsrc/frontend/optimizer/irpass.cc
  4. +3 -0    mindspore/ccsrc/frontend/optimizer/irpass.h
  5. +1 -1    mindspore/ccsrc/frontend/optimizer/irpass/inline.h
  6. +47 -0   mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h
  7. +1 -0    mindspore/ccsrc/pipeline/jit/pass.cc
  8. +1 -0    mindspore/core/ir/func_graph.cc
  9. +5 -0    mindspore/core/ir/func_graph.h
  10. +2 -0   mindspore/core/ir/func_graph_cloner.cc
  11. +20 -0  tests/ut/python/ops/test_control_ops.py

mindspore/ccsrc/debug/anf_ir_utils.cc  +3 -0

@@ -607,6 +607,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
  std::vector<AnfNodePtr> parameters = func_graph->parameters();
  OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;

  if (*(func_graph->switch_layer_input())) {
    ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
  }
  ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
      << func_graph->debug_info()->get_id() << "\n";
  if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {


mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc  +2 -0

@@ -49,6 +49,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
    std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
    k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
  }
  // To keep switch_layer's inputs from being inlined
  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  TraceManager::EndTrace();

  TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));


mindspore/ccsrc/frontend/optimizer/irpass.cc  +5 -0

@@ -45,6 +45,7 @@
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"

namespace mindspore {
namespace opt {
@@ -170,6 +171,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  // Value_Based Eliminate
  value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
                                            {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});

  // switch_layer defer inline
  switch_layer_defer_inline_ =
    MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
}

ResolveIRPassLib::ResolveIRPassLib() {


mindspore/ccsrc/frontend/optimizer/irpass.h  +3 -0

@@ -113,6 +113,9 @@ class OptimizeIRPassLib {

  // Value_Based Eliminate
  SubstitutionPtr value_based_eliminate_;

  // SwitchLayer defer inline
  SubstitutionPtr switch_layer_defer_inline_;
};

// the collection of irpass for resolve action


mindspore/ccsrc/frontend/optimizer/irpass/inline.h  +1 -1

@@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
    }

    auto fg = GetValueNode<FuncGraphPtr>(node);
-   if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
+   if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
      return nullptr;
    }



mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h  +47 -0

@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_

#include <vector>
#include <algorithm>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSwitchLayer, {Index, layers}}
class SwitchLayerDeferInline : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    auto cnode = node->cast<CNodePtr>();
    auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
    for (auto elem : tuple->elements()) {
      auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
      *(abstract->func_graph()->switch_layer_input()) = true;
    }
    return nullptr;
  }
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
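
How this pass and the inliner guard fit together: when the visitor matches a switch_layer call it sets the shared switch_layer_input flag on every branch graph, and the new check in ReplaceApplicator (inline.h) then refuses to replace those graphs, which keeps single-primitive branches such as nn.ReLU from being inlined away before switch_layer is handled. A minimal standalone sketch of that interplay; Graph, MarkSwitchLayerBranches, and CanReplace are illustrative names, not MindSpore APIs:

#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-in for FuncGraph: only the shared flag matters here.
struct Graph {
  std::shared_ptr<bool> switch_layer_input = std::make_shared<bool>(false);
};

// Rough analogue of SwitchLayerDeferInline: mark every graph fed to switch_layer.
void MarkSwitchLayerBranches(const std::vector<Graph *> &branches) {
  for (auto *branch : branches) {
    *branch->switch_layer_input = true;
  }
}

// Rough analogue of the new guard in ReplaceApplicator (inline.h).
bool CanReplace(const Graph &g) { return !*g.switch_layer_input; }

int main() {
  Graph relu_branch;  // a single-primitive cell used as a switch_layer branch
  Graph other_graph;  // an unrelated graph
  MarkSwitchLayerBranches({&relu_branch});
  std::cout << std::boolalpha << CanReplace(relu_branch) << "\n";  // false: inlining deferred
  std::cout << std::boolalpha << CanReplace(other_graph) << "\n";  // true: unaffected
  return 0;
}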

mindspore/ccsrc/pipeline/jit/pass.cc  +1 -0

@@ -90,6 +90,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
namespace {
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = opt::OptPassConfig({
    irpass.switch_layer_defer_inline_,
    irpass.switch_simplify_,

    // Safe inlining


mindspore/core/ir/func_graph.cc  +1 -0

@@ -48,6 +48,7 @@ FuncGraph::FuncGraph()
      manager_(std::weak_ptr<FuncGraphManager>()),
      stub_(false) {
  debug_info_ = std::make_shared<GraphDebugInfo>();
  switch_layer_input_ = std::make_shared<bool>(false);
}

abstract::AbstractBasePtr FuncGraph::ToAbstract() {


mindspore/core/ir/func_graph.h  +5 -0

@@ -353,6 +353,8 @@ class FuncGraph : public FuncGraphBase {
  bool stub() const { return stub_; }
  void set_stub(bool stub) { stub_ = stub; }
  static void set_drawer(Drawer drawer) { drawer_ = drawer; }
  std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
  void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }

 private:
  // graph is manipulated by manager and others
@@ -414,6 +416,9 @@ class FuncGraph : public FuncGraphBase {
  std::list<CNodePtr> order_;
  bool stub_;
  inline static Drawer drawer_ = nullptr;
  // Design switch_layer_input as a ptr to
  // share between derived backpropagator and cloned graphs
  std::shared_ptr<bool> switch_layer_input_;
};

inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
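
The flag is a std::shared_ptr<bool> rather than a plain bool so that, per the comment above, the original graph, its clones, and the backprop graph built by DFunctor all alias one value: the constructor allocates it once and set_switch_layer_input copies the pointer. A minimal standalone sketch of that aliasing (plain C++, illustrative names, not MindSpore code):

#include <iostream>
#include <memory>

int main() {
  // Allocated once, as in the FuncGraph constructor.
  auto original = std::make_shared<bool>(false);

  // set_switch_layer_input copies the shared_ptr, so a clone aliases the same bool.
  std::shared_ptr<bool> cloned = original;

  // When the defer-inline pass later sets the flag through one handle...
  *original = true;

  // ...every graph sharing the pointer observes the change.
  std::cout << std::boolalpha << *cloned << std::endl;  // prints "true"
  return 0;
}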


mindspore/core/ir/func_graph_cloner.cc  +2 -0

@@ -228,6 +228,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
  (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
  (*target_func_graph)->set_is_generate(func_graph->is_generated());
  (*target_func_graph)->set_stub(func_graph->stub());
  (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
  TraceManager::EndTrace();
}

@@ -645,6 +646,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
  new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  new_func_graph->set_is_generate(func_graph->is_generated());
  new_func_graph->set_stub(func_graph->stub());
  new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
  for (auto &item : func_graph->parameter_default_value()) {
    new_func_graph->set_param_default_value(item.first, cloner[item.second]);
  }


tests/ut/python/ops/test_control_ops.py  +20 -0

@@ -444,6 +444,26 @@ def test_index_to_switch_layer():
    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_switch_layer_with_single_prim():
    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (nn.ReLU(), nn.ReLU())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_control_depend_check():
    with pytest.raises(TypeError) as e:
        P.ControlDepend(0.0)

