Browse Source

Don't lift fv before grad in higher-order grad if the closure is kept, because bprop of the Partial primitive is not supported

tags/v1.6.0
zhousiyi 4 years ago
parent
commit
d9aa48bc64
7 changed files with 47 additions and 21 deletions
  1. +17
    -4
      mindspore/ccsrc/frontend/optimizer/ad/grad.cc
  2. +2
    -4
      mindspore/ccsrc/frontend/optimizer/ad/grad.h
  3. +11
    -8
      mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
  4. +11
    -0
      mindspore/ccsrc/frontend/optimizer/optimizer.h
  5. +1
    -1
      mindspore/ccsrc/pipeline/jit/action.cc
  6. +1
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  7. +4
    -3
      tests/ut/cpp/optimizer/ad/ad_test.cc

+ 17
- 4
mindspore/ccsrc/frontend/optimizer/ad/grad.cc View File

@@ -17,6 +17,7 @@
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/dead_node_eliminate.h"
#include "ir/func_graph_cloner.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
@@ -24,7 +25,7 @@
namespace mindspore {
namespace ad {
namespace {
FuncGraphPtr PartialEliminateOptPass(const ResourcePtr &resource, const FuncGraphPtr &func_graph) {
FuncGraphPtr PartialEliminateOptPass(const pipeline::ResourcePtr &resource, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(resource);

opt::irpass::OptimizeIRPassLib irpass;
@@ -68,20 +69,32 @@ FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPt
}
} // namespace

FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimzer, bool is_top) {
MS_EXCEPTION_IF_NULL(func_graph);
auto gradkv = func_graph->transforms().find("grad");
if (gradkv != func_graph->transforms().end()) {
return gradkv->second.func_graph();
}

const auto &resources = optimzer->resource();
auto manager_ptr = resources->manager();
MS_EXCEPTION_IF_NULL(manager_ptr);
manager_ptr->AddFuncGraph(func_graph);

FuncGraphPtr grad_fg = func_graph;
if (func_graph->func_graphs_used().size() != 0) {
grad_fg = LiftFv(resources, func_graph);
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
if (enable_closure) {
if (func_graph->func_graphs_used().size() != 0 && optimzer->is_first_order_j()) {
lift_fv_before_grad = true;
grad_fg = LiftFv(resources, func_graph);
} else {
lift_fv_before_grad = false;
opt::EliminateDeadNode(grad_fg);
}
} else {
if (func_graph->func_graphs_used().size() != 0) {
grad_fg = LiftFv(resources, func_graph);
}
}
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {


+ 2
- 4
mindspore/ccsrc/frontend/optimizer/ad/grad.h View File

@@ -22,13 +22,11 @@

#include "ir/anf.h"
#include "ir/meta_func_graph.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/optimizer.h"

namespace mindspore {
namespace ad {
using ResourcePtr = std::shared_ptr<pipeline::Resource>;

FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true);
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer, bool is_top = true);
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &);
void CleanRes();


+ 11
- 8
mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc View File

@@ -83,7 +83,8 @@ void CheckSwitchWithSideEffect(const FuncGraphPtr &fg) {
}
}

AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) {
AnfNodePtr expanded_node = nullptr;
if (IsValueNode<FuncGraph>(vnode)) {
ScopeGuard scope_guard(vnode->scope());
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
@@ -92,13 +93,15 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &r
CheckSwitchWithSideEffect(func_graph);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now";
auto newfg = ad::Grad(func_graph, resource);
return NewValueNode(newfg);
auto newfg = ad::Grad(func_graph, optimizer);
expanded_node = NewValueNode(newfg);
} else if (IsValueNode<Primitive>(vnode)) {
expanded_node = ExpandJPrimitive(vnode, optimizer->resource());
} else {
return nullptr;
}
if (IsValueNode<Primitive>(vnode)) {
return ExpandJPrimitive(vnode, resource);
}
return nullptr;
optimizer->set_is_first_order_j(false);
return expanded_node;
}
} // namespace internal

@@ -122,7 +125,7 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
bool change = false;
auto manager = optimizer->manager();
for (auto &j_node : todo) {
auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer);
manager->Replace(j_node, expanded_j);
change = true;
}


+ 11
- 0
mindspore/ccsrc/frontend/optimizer/optimizer.h View File

@@ -142,6 +142,12 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
return optimizer;
}

static std::shared_ptr<Optimizer> MakeEmptyOptimizer(const pipeline::ResourceBasePtr resource_ptr) {
OptimizerPtr optimizer = std::make_shared<Optimizer>("empty", resource_ptr, false);
optimizer->Init(OptPassGroupMap{}, false);
return optimizer;
}

FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
if (!is_enable_) {
return func_graph;
@@ -240,6 +246,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {

bool traverse_nodes_first() { return traverse_nodes_first_; }

bool is_first_order_j() { return is_first_order_j_; }
void set_is_first_order_j(bool is_first_order_j) { is_first_order_j_ = is_first_order_j; }

struct {
int64_t counter;
std::string name;
@@ -257,6 +266,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool is_enable_;
bool is_untyped_generated_;
bool traverse_nodes_first_;
// A flag to indicate if it's the first order J or innermost J in GraphMode.
bool is_first_order_j_{true};
};
} // namespace opt
} // namespace mindspore


+ 1
- 1
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -725,7 +725,7 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
auto grad_exec = pynative_exec->grad_executor();
bool eliminate_forward = grad_exec->eliminate_forward();
grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty());
auto grad_graph = ad::Grad(ms_func_graph, res);
auto grad_graph = ad::Grad(ms_func_graph, opt::Optimizer::MakeEmptyOptimizer(res));
MS_EXCEPTION_IF_NULL(grad_graph);
graph_executor->SetGradGraph(grad_graph, phase);
ModifyOutputNode(ms_func_graph);


+ 1
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -3167,7 +3167,7 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const py::tuple &forw
r->manager()->AddFuncGraph(first_grad_fg);
set_eliminate_forward(false);
first_grad_fg->transforms().erase(kGrad);
FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, opt::Optimizer::MakeEmptyOptimizer(r));
set_eliminate_forward(true);
DumpGraphIR("second_grad_fg.ir", second_grad_fg);
r->Clean();


+ 4
- 3
tests/ut/cpp/optimizer/ad/ad_test.cc View File

@@ -28,6 +28,7 @@
#include "pipeline/jit/parse/parse.h"
#include "debug/draw.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"

namespace mindspore {
namespace ad {
@@ -44,7 +45,7 @@ class TestAD : public UT::Common {
FuncGraphPtr g = getPyFun(testCase);
resourcePtr->manager()->RemoveRoots();
resourcePtr->manager()->AddFuncGraph(g, true);
FuncGraphPtr dg = Grad(g, resourcePtr);
FuncGraphPtr dg = Grad(g, opt::Optimizer::MakeEmptyOptimizer(resourcePtr));
AssertExpect(testCase, dg);
}

@@ -188,8 +189,8 @@ TEST_F(TestAD, test_prim_switch) {

TEST_F(TestAD, test_grad_cache) {
FuncGraphPtr g = getPyFun("test_null");
FuncGraphPtr dg1 = Grad(g, resourcePtr);
FuncGraphPtr dg2 = Grad(g, resourcePtr);
FuncGraphPtr dg1 = Grad(g, opt::Optimizer::MakeEmptyOptimizer(resourcePtr));
FuncGraphPtr dg2 = Grad(g, opt::Optimizer::MakeEmptyOptimizer(resourcePtr));
ASSERT_TRUE(dg1 == dg2);
}



Loading…
Cancel
Save