Browse Source

!431 fix cell output issue in pynative mode

Merge pull request !431 from wangqiuliang/fix-cell-ouput-issue-in-pynative
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
59eb8aa053
3 changed files with 8 additions and 5 deletions
  1. +3
    -3
      mindspore/ccsrc/pynative/pynative_execute.cc
  2. +1
    -1
      mindspore/ccsrc/session/anf_runtime_algorithm.cc
  3. +4
    -1
      mindspore/nn/cell.py

+ 3
- 3
mindspore/ccsrc/pynative/pynative_execute.cc View File

@@ -39,7 +39,7 @@

const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::unordered_set<std::string> vm_operators = {"partial", "depend"};
const std::unordered_set<std::string> vm_operators = {"partial", "depend", "make_ref"};

namespace mindspore {
namespace pynative {
@@ -141,7 +141,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
op_exec_info->op_inputs = py_args;
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
return nullptr;
}
return op_exec_info;
@@ -163,7 +163,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) {
// get prim and abstract info
(void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
op_exec_info->abstract->ToString());
MS_LOG(INFO) << "graph info [" << graph_info << "]";
MS_LOG(INFO) << "Graph info [" << graph_info << "]";
return graph_info;
}



+ 1
- 1
mindspore/ccsrc/session/anf_runtime_algorithm.cc View File

@@ -457,7 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
} else if (tuple_i->isa<Number>()) {
return tuple_i->type_id();
} else {
MS_LOG(EXCEPTION) << "Not support type " << tuple_i->ToString();
MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
return tuple_i->type_id();
}
} else if (type_ptr->isa<Number>()) {


+ 4
- 1
mindspore/nn/cell.py View File

@@ -140,7 +140,10 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
return self.construct(*inputs)
output = self.construct(*inputs)
if isinstance(output, Parameter):
output = output.data
return output

def __setattr__(self, name, value):
cells = self.__dict__.get('_cells')


Loading…
Cancel
Save