Browse Source

!9711 change cell id for pynative to fix clipnorm execute error

From: @chujinjin
Reviewed-by: @limingqi107,@zhoufeng54
Signed-off-by: @zhoufeng54
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b5be4e0402
2 changed files with 4 additions and 5 deletions
  1. +4
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +0
    -3
      mindspore/nn/layer/basic.py

+ 4
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -1384,12 +1384,14 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
std::string arg_id = GetId(args[i]);
auto it = node_abs_map_.find(arg_id);
if (it != node_abs_map_.end()) {
cell_id += it->second->ToString();
cell_id += "_" + it->second->BuildShape()->ToString();
cell_id += "_" + it->second->BuildType()->ToString();
} else {
auto abs = PyAttrValue(args[i])->ToAbstract();
auto config = abstract::AbstractBase::kBroadenTensorOnly;
abs = abs->Broaden(config);
cell_id += abs->ToString();
cell_id += "_" + abs->BuildShape()->ToString();
cell_id += "_" + abs->BuildType()->ToString();
node_abs_map_[arg_id] = abs;
}
}


+ 0
- 3
mindspore/nn/layer/basic.py View File

@@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator
from mindspore.common.api import ms_function
from mindspore import context
from ..cell import Cell
from .activation import get_activation
@@ -413,9 +412,7 @@ class ClipByNorm(Cell):
self.expand_dims = P.ExpandDims()
self.dtype = P.DType()

@ms_function
def construct(self, x, clip_norm):
"""add ms_function decorator for pynative mode"""
mul_x = F.square(x)
l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
cond = self.greater_(l2sum, 0)


Loading…
Cancel
Save