|
|
@@ -106,28 +106,23 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
FuncGraphPtr bprop_fg = nullptr; |
|
|
FuncGraphPtr bprop_fg = nullptr; |
|
|
auto iter = bprop_registry_.find(prim); |
|
|
|
|
|
if (iter != bprop_registry_.end()) { |
|
|
|
|
|
bprop_fg = iter->second; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { |
|
|
|
|
|
bprop_fg = BpropCut(value_node, resources); |
|
|
|
|
|
} else { |
|
|
|
|
|
auto iter = bprop_registry_.find(prim); |
|
|
|
|
|
if (iter != bprop_registry_.end()) { |
|
|
|
|
|
bprop_fg = iter->second; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (bprop_fg == nullptr) { |
|
|
|
|
|
bool is_faked_bprop = false; |
|
|
|
|
|
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { |
|
|
|
|
|
bprop_fg = BpropCut(value_node, resources); |
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
|
if (bprop_fg == nullptr) { |
|
|
bprop_fg = GetBprop(prim); |
|
|
bprop_fg = GetBprop(prim); |
|
|
if (bprop_fg == nullptr) { |
|
|
|
|
|
|
|
|
if (bprop_fg != nullptr) { |
|
|
|
|
|
// Set bprop_g graph cache |
|
|
|
|
|
bprop_registry_[prim] = bprop_fg; |
|
|
|
|
|
} else { |
|
|
bprop_fg = FakeBprop(value_node, resources); |
|
|
bprop_fg = FakeBprop(value_node, resources); |
|
|
is_faked_bprop = true; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// To support primitives with variable params, do not cache faked bprop |
|
|
|
|
|
if (!is_faked_bprop) { |
|
|
|
|
|
// Set bprop_g graph cache |
|
|
|
|
|
bprop_registry_[prim] = bprop_fg; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto expanded_fg = BpropToK(prim, bprop_fg); |
|
|
auto expanded_fg = BpropToK(prim, bprop_fg); |
|
|
|