Browse Source

Fix bug of assign value to non Parameter class member

tags/v0.3.0-alpha
fary86 5 years ago
parent
commit
16b9004d53
2 changed files with 23 additions and 2 deletions
  1. +22
    -1
      mindspore/ccsrc/pipeline/parse/parse.cc
  2. +1
    -1
      tests/ut/python/pynative_mode/test_insert_grad_of.py

+ 22
- 1
mindspore/ccsrc/pipeline/parse/parse.cc View File

@@ -1136,10 +1136,31 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
AnfNodePtr target_node = ParseExprNode(block, targ); AnfNodePtr target_node = ParseExprNode(block, targ);
MS_EXCEPTION_IF_NULL(target_node); MS_EXCEPTION_IF_NULL(target_node);


std::string attr_name = targ.attr("attr").cast<std::string>();
std::string var_name = "self."; std::string var_name = "self.";
(void)var_name.append(targ.attr("attr").cast<std::string>());
(void)var_name.append(attr_name);
MS_LOG(DEBUG) << "assign " << var_name; MS_LOG(DEBUG) << "assign " << var_name;


// Get targ location info for error printing
py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ);
if (location.size() < 2) {
MS_LOG(EXCEPTION) << "List size should not be less than 2.";
}
auto filename = location[0].cast<std::string>();
auto line_no = location[1].cast<int>();
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
if (!py::hasattr(ast()->obj(), attr_name.c_str())) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":"
<< line_no;
}
auto obj = ast()->obj().attr(attr_name.c_str());
auto obj_type = obj.attr("__class__").attr("__name__");
if (!py::hasattr(obj, "__parameter__")) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
<< py::str(obj).cast<std::string>() << "' with type '"
<< py::str(obj_type).cast<std::string>() << "' at " << filename << ":" << line_no;
}

MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
block->WriteVariable(var_name, assigned_node); block->WriteVariable(var_name, assigned_node);
MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();


+ 1
- 1
tests/ut/python/pynative_mode/test_insert_grad_of.py View File

@@ -124,9 +124,9 @@ def test_cell_assign():
class Mul(nn.Cell): class Mul(nn.Cell):
def __init__(self): def __init__(self):
super(Mul, self).__init__() super(Mul, self).__init__()
self.get_g = P.InsertGradientOf(self.save_gradient)
self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w") self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w")
self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g") self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g")
self.get_g = P.InsertGradientOf(self.save_gradient)


def save_gradient(self, dout): def save_gradient(self, dout):
self.matrix_g = dout + self.matrix_g self.matrix_g = dout + self.matrix_g


Loading…
Cancel
Save