Browse Source

!6524 Pynative fix bug of tuple set item index wrong

Merge pull request !6524 from JoyLvliang/pynative-fix-bug-of-tuple-set-item-index-wrong
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
851d3d9dff
2 changed files with 1 additions and 7 deletions
  1. +1
    -1
      mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
  2. +0
    -6
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 1
- 1
mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h View File

@@ -255,7 +255,7 @@ class PynativeEliminater : public OptimizerCaller {
MS_LOG(DEBUG) << "Start FillZero";
ValuePtr out = nullptr;
if (value->isa<Int32Imm>()) {
return MakeValue(0);
return value;
}

if (value->isa<tensor::Tensor>()) {


+ 0
- 6
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -70,7 +70,6 @@ const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertG

namespace mindspore {
namespace pynative {

static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
@@ -1213,7 +1212,6 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
}
return graph_info_map_[df_builder_].param_map[obj_id].first;
}

// if input is graph output
if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
// op(x, y)
@@ -1227,20 +1225,16 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
// out = op((x, y))
// out = cell((x, y))
auto tuple = obj.cast<py::tuple>();

// cell((1,2)): support not mix (scalar, tensor)
if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
return MakeValueNode(obj, obj_id);
}

std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));

auto tuple_size = static_cast<int>(tuple.size());
for (int i = 0; i < tuple_size; i++) {
args.push_back(GetInput(tuple[i], false));
}

auto cnode = curr_g_->NewCNode(args);
set_obj_node_map(curr_g_, GetId(obj), cnode);
node = cnode;


Loading…
Cancel
Save