Browse Source

!2767 fix hook operator compare issue

Merge pull request !2767 from wangqiuliang/fix-hook-operator-compare-issue
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
528ce2b96b
2 changed files with 15 additions and 17 deletions
  1. +12
    -17
      mindspore/ccsrc/optimizer/ad/kprim.cc
  2. +3
    -0
      mindspore/common/tensor.py

+ 12
- 17
mindspore/ccsrc/optimizer/ad/kprim.cc View File

@@ -106,28 +106,23 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
}

FuncGraphPtr bprop_fg = nullptr;
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}

if (bprop_fg == nullptr) {
bool is_faked_bprop = false;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
if (bprop_fg == nullptr) {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
if (bprop_fg != nullptr) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
} else {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}
}

// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
}
}

auto expanded_fg = BpropToK(prim, bprop_fg);


+ 3
- 0
mindspore/common/tensor.py View File

@@ -109,6 +109,9 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__neg__')(self)
return out

def __pos__(self):
return self

def __iadd__(self, other):
return self.__add__(other)



Loading…
Cancel
Save