From 5d301092f6d7e299df8c2b375770fc7f9ca8003d Mon Sep 17 00:00:00 2001 From: kingfo Date: Tue, 30 Jun 2020 20:58:55 +0800 Subject: [PATCH] fix hook operator compare issue --- mindspore/ccsrc/optimizer/ad/kprim.cc | 29 +++++++++++---------------- mindspore/common/tensor.py | 3 +++ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc index 791279b1a1..4141fb5413 100644 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/optimizer/ad/kprim.cc @@ -106,28 +106,23 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R } FuncGraphPtr bprop_fg = nullptr; - auto iter = bprop_registry_.find(prim); - if (iter != bprop_registry_.end()) { - bprop_fg = iter->second; - } + if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { + bprop_fg = BpropCut(value_node, resources); + } else { + auto iter = bprop_registry_.find(prim); + if (iter != bprop_registry_.end()) { + bprop_fg = iter->second; + } - if (bprop_fg == nullptr) { - bool is_faked_bprop = false; - if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { - bprop_fg = BpropCut(value_node, resources); - } else { + if (bprop_fg == nullptr) { bprop_fg = GetBprop(prim); - if (bprop_fg == nullptr) { + if (bprop_fg != nullptr) { + // Set bprop_g graph cache + bprop_registry_[prim] = bprop_fg; + } else { bprop_fg = FakeBprop(value_node, resources); - is_faked_bprop = true; } } - - // To support primitives with variable params, do not cache faked bprop - if (!is_faked_bprop) { - // Set bprop_g graph cache - bprop_registry_[prim] = bprop_fg; - } } auto expanded_fg = BpropToK(prim, bprop_fg); diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 4bb845af55..043ab4f6cf 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -109,6 +109,9 @@ class Tensor(Tensor_): out = tensor_operator_registry.get('__neg__')(self) return out + def __pos__(self): + return self + def __iadd__(self, other): return self.__add__(other)