|
|
|
@@ -657,14 +657,17 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i |
|
|
|
if (cnode != nullptr) { |
|
|
|
cnode->set_abstract(op_exec_info->abstract); |
|
|
|
} |
|
|
|
std::string obj_id = GetId(out_real); |
|
|
|
node_abs_map_[obj_id] = op_exec_info->abstract; |
|
|
|
|
|
|
|
// Save info for building grad graph |
|
|
|
if (grad()->grad_flag() && grad()->in_grad_process()) { |
|
|
|
std::string obj_id = GetId(out_real); |
|
|
|
node_abs_map_[obj_id] = op_exec_info->abstract; |
|
|
|
grad()->SaveOutputNodeMap(obj_id, out_real, cnode); |
|
|
|
grad()->SaveAllResult(op_exec_info, cnode, out_real); |
|
|
|
// Update the abstract and device address of value node with tensor in grad graph |
|
|
|
UpdateAbstractAndDeviceAddress(op_exec_info, out_real); |
|
|
|
} else { |
|
|
|
node_abs_map_.clear(); |
|
|
|
} |
|
|
|
*ret = out_real; |
|
|
|
} |
|
|
|
@@ -810,8 +813,8 @@ abstract::AbstractBasePtr ForwardExecutor::CheckConstValue(const PrimitivePyPtr |
|
|
|
MS_EXCEPTION_IF_NULL(new_abs); |
|
|
|
new_abs = new_abs->Broaden(config); |
|
|
|
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config; |
|
|
|
node_abs_map_[id] = new_abs; |
|
|
|
} |
|
|
|
node_abs_map_[id] = new_abs; |
|
|
|
} |
|
|
|
return new_abs; |
|
|
|
} |
|
|
|
|