Browse Source

fix bprop cache caused error with variable params

tags/v0.2.0-alpha
义峰潘 panyifeng 6 years ago
parent
commit
2bef22d8a3
2 changed files with 19 additions and 2 deletions
  1. +7
    -2
      mindspore/ccsrc/optimizer/ad/kprim.cc
  2. +12
    -0
      tests/ut/python/pynative_mode/test_stop_gradient.py

+ 7
- 2
mindspore/ccsrc/optimizer/ad/kprim.cc View File

@@ -92,9 +92,11 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
return nullptr;
}

bool is_faked_bprop = false;
auto bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}

auto expanded_fg = BpropToK(prim, bprop_fg);
@@ -104,8 +106,11 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
<< trace::GetDebugInfo(bprop_fg->debug_info());
}

// Set bprop_g graph cache
bprop_registry_[prim] = expanded_fg;
// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = expanded_fg;
}
return expanded_fg;
}



+ 12
- 0
tests/ut/python/pynative_mode/test_stop_gradient.py View File

@@ -366,3 +366,15 @@ def test_stop_gradient_11():
with pytest.raises(RuntimeError):
bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
Tensor(np.ones([2]).astype(np.float32)))

def test_stop_print():
class StopPrint(nn.Cell):
def __init__(self):
super(StopPrint, self).__init__()
self.printm = P.Print()
def construct(self, x, y):
self.printm("StopPrint", x)
self.printm(y)
return x, y
C.grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
Tensor(np.ones([2]).astype(np.float32)))

Loading…
Cancel
Save