Browse Source

cache bprop instead of fprop

tags/v0.6.0-beta
wuyongkang 5 years ago
parent
commit
e606068585
4 changed files with 32 additions and 32 deletions
  1. +0
    -1
      mindspore/ccsrc/optimizer/ad/dfunctor.cc
  2. +4
    -4
      mindspore/ccsrc/optimizer/ad/dfunctor.h
  3. +24
    -26
      mindspore/ccsrc/optimizer/ad/kprim.cc
  4. +4
    -1
      tests/ut/cpp/operator/grad_implementations_test.cc

+ 0
- 1
mindspore/ccsrc/optimizer/ad/dfunctor.cc View File

@@ -424,7 +424,6 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
}
auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
if (k_prim != nullptr) {
k_prim = BasicClone(k_prim);
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.


+ 4
- 4
mindspore/ccsrc/optimizer/ad/dfunctor.h View File

@@ -47,12 +47,12 @@ struct PrimitiveTotalEqual {
return false;
}

for (auto &attr : attrs1) {
if (!t2->HasAttr(attr.first)) {
for (auto &attr1 : attrs1) {
if (!t2->HasAttr(attr1.first)) {
return false;
}

if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) {
return false;
}
}
@@ -61,7 +61,7 @@ struct PrimitiveTotalEqual {
}
};

using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher>;
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim;
extern KPrim g_k_prims;
class DFunctor;


+ 24
- 26
mindspore/ccsrc/optimizer/ad/kprim.cc View File

@@ -96,34 +96,37 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
}

auto prim = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);

auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
return iter->second;
}

if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
auto fprop = GetFprop(prim);
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
return fprop;
}

if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
return nullptr;
}

bool is_faked_bprop = false;
FuncGraphPtr bprop_fg = nullptr;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
bprop_fg = BpropCut(value_node, resources);
} else {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}

if (bprop_fg == nullptr) {
bool is_faked_bprop = false;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}
}

// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
}
}

@@ -134,11 +137,6 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
<< trace::GetDebugInfo(bprop_fg->debug_info());
}

// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = expanded_fg;
}
return expanded_fg;
}



+ 4
- 1
tests/ut/cpp/operator/grad_implementations_test.cc View File

@@ -38,7 +38,10 @@ TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);

auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
ASSERT_TRUE(fg == fg1);

FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
ASSERT_TRUE(Isomorphic(fg, fg1, &equiv_graph, &equiv_node));
}

} // namespace prim


Loading…
Cancel
Save