Browse Source

raise exception when use HookBackward in graph mode

tags/v1.0.0
buxue 5 years ago
parent
commit
3fd73f9d08
6 changed files with 8 additions and 5 deletions
  1. +3
    -0
      mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
  2. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  3. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
  4. +1
    -1
      mindspore/ccsrc/pipeline/jit/validator.cc
  5. +1
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  6. +1
    -1
      mindspore/nn/cell.py

+ 3
- 0
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc View File

@@ -120,6 +120,9 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R

FuncGraphPtr bprop_fg = nullptr;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode.";
}
bprop_fg = BpropCut(value_node, resources);
} else {
auto iter = bprop_registry_.find(prim);


+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -1198,7 +1198,7 @@ void ClearPrimEvaluatorMap() {
GetUniformPrimitiveToImplMap().clear();
}

bool IsInWhiteList(const PrimitivePtr primitive) {
bool IsInWhiteList(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);

auto iter = GetPrimitiveToEvalImplMap().find(primitive);


+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h View File

@@ -111,7 +111,7 @@ class MixedPrecisionCastEvaluator : public Evaluator {
PrimitivePtr prim_;
};

bool IsInWhiteList(PrimitivePtr primitive);
bool IsInWhiteList(const PrimitivePtr &primitive);
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);

using ValuePtrList = std::vector<ValuePtr>;


+ 1
- 1
mindspore/ccsrc/pipeline/jit/validator.cc View File

@@ -47,7 +47,7 @@ void ValidateOperation(const AnfNodePtr &node) {
}

// Primitive must in whitelist
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto prim = GetValueNode<PrimitivePtr>(node);
if (abstract::IsInWhiteList(prim)) {
return;
}


+ 1
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -1257,7 +1257,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
if (need_replace_param) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
for (size_t i = 0; i < params.size(); i++) {
for (size_t i = 0; i < args.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);
auto v_node = NewValueNode(value);
manager->Replace(params[i], v_node);


+ 1
- 1
mindspore/nn/cell.py View File

@@ -294,7 +294,7 @@ class Cell(Cell_):
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
self._already_run = True
self._already_run = True
return output

def _add_attr(self, name, value):


Loading…
Cancel
Save