Browse Source

fix mobilenetv2 loss error in pynative mode

tags/v1.0.0
chujinjin 5 years ago
parent
commit
62bfaf7e91
2 changed files with 16 additions and 2 deletions
  1. +15
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +1
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 15
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -776,9 +776,14 @@ void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr
if (iter != op_forward_map_.end()) { if (iter != op_forward_map_.end()) {
return; return;
} }
op_forward_map_[id] = value;
auto tuple_info_iter = obj_to_forward_id_tuple_info_.find(id);
ValuePtr temp_value = value;
if (tuple_info_iter != obj_to_forward_id_tuple_info_.end()) {
temp_value = tuple_info_iter->second;
}
op_forward_map_[id] = temp_value;
MS_LOG(DEBUG) << "Save op forward value: " MS_LOG(DEBUG) << "Save op forward value: "
<< "(" << id << "), " << value;
<< "(" << id << "), " << temp_value;
} }


void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
@@ -808,6 +813,14 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
cnode->set_forward(value, op_id); cnode->set_forward(value, op_id);
++op_id_map_[id]; ++op_id_map_[id];
auto out_id = GetId(out_real); auto out_id = GetId(out_real);
if (py::isinstance<py::tuple>(out_real)) {
auto tuple_item = py::cast<py::tuple>(out_real);
for (size_t i = 0; i < tuple_item.size(); i++) {
auto tuple_item_id = GetId(tuple_item[i]);
obj_to_forward_id_[tuple_item_id] = op_id;
}
obj_to_forward_id_tuple_info_[op_id] = value;
}
obj_to_forward_id_[out_id] = op_id; obj_to_forward_id_[out_id] = op_id;
} }
} }


+ 1
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -154,6 +154,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<std::string, ValuePtr> op_forward_map_; std::unordered_map<std::string, ValuePtr> op_forward_map_;
std::unordered_map<std::string, size_t> op_id_map_; std::unordered_map<std::string, size_t> op_id_map_;
std::unordered_map<std::string, std::string> obj_to_forward_id_; std::unordered_map<std::string, std::string> obj_to_forward_id_;
std::unordered_map<std::string, ValuePtr> obj_to_forward_id_tuple_info_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_; std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::unordered_map<std::string, FuncGraphPtr> df_builder_map_; std::unordered_map<std::string, FuncGraphPtr> df_builder_map_;
// the stack that records the context of graph created, the bottom is the top graph // the stack that records the context of graph created, the bottom is the top graph


Loading…
Cancel
Save