Browse Source

!1945 [bug]fix bug in '=', use signature to support auto cast in assign.

Merge pull request !1945 from vlne-v1/I1JXUP-resnet50-thor-assign-fail
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
0e7839826e
4 changed files with 26 additions and 9 deletions
  1. +2
    -1
      mindspore/ccsrc/operator/ops.h
  2. +2
    -2
      mindspore/ccsrc/operator/ops_extends.cc
  3. +2
    -5
      mindspore/ccsrc/pipeline/parse/function_block.cc
  4. +20
    -1
      tests/ut/python/ops/test_signature.py

+ 2
- 1
mindspore/ccsrc/operator/ops.h View File

@@ -27,7 +27,8 @@ namespace mindspore {
// namespace to support primitive operators
namespace prim {
ValuePtr GetPythonOps(const std::string &op_name,
const std::string &module_name = "mindspore._extends.parse.standard_method");
const std::string &module_name = "mindspore._extends.parse.standard_method",
bool use_signature = false);

// Arithmetic
extern const PrimitivePtr kPrimScalarAdd;


+ 2
- 2
mindspore/ccsrc/operator/ops_extends.cc View File

@@ -23,10 +23,10 @@
namespace mindspore {
// namespace to support primitive operators
namespace prim {
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) {
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
ValuePtr node = nullptr;
bool succ = parse::ConvertData(obj, &node);
bool succ = parse::ConvertData(obj, &node, use_signature);
if (!succ) {
MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
}


+ 2
- 5
mindspore/ccsrc/pipeline/parse/function_block.cc View File

@@ -322,12 +322,10 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {

ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional");
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name));

ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
return;
}
@@ -336,8 +334,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
vec_states.emplace_back(make_tuple_op);
for (auto &item : state_assign_) {
auto source = ReadVariable(item.second);
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
auto assign = func_graph()->NewCNode({assign_op, origin, source});
auto assign = func_graph()->NewCNode({assign_op, item.first, source});
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
vec_states.emplace_back(assign);
}


+ 20
- 1
tests/ut/python/ops/test_signature.py View File

@@ -47,7 +47,7 @@ class Net(nn.Cell):


def test_assign_through_cell():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
net = Net()
net.to_float(ms.float16)
net.add_flags_recursive(fp16=False)
@@ -57,6 +57,25 @@ def test_assign_through_cell():
net(None)


class AssignOp(nn.Cell):
def __init__(self):
super(AssignOp, self).__init__()
self.b = Parameter(initializer('ones', [5]), name='b')


def construct(self, w):
self.b = w
return w


def test_assign_by_operator():
context.set_context(mode=context.GRAPH_MODE)
net = AssignOp()
net.to_float(ms.float16)
input_data = Tensor(np.ones([5]).astype(np.float32))
net(input_data)


class NetScatterNdUpdate(nn.Cell):
def __init__(self):
super(NetScatterNdUpdate, self).__init__()


Loading…
Cancel
Save