@@ -994,7 +994,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
     // out = op(cell1(x, y))
     // out = op(cell1(x, y)[0])
     node = GetObjNode(obj, obj_id);
-  } else if (py::isinstance<py::tuple>(obj)) {
+  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
     // out = op((x, y))
     // out = cell((x, y))
     auto tuple = obj.cast<py::tuple>();
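
The one-line change above widens the sequence branch in `GetInput` so that Python lists take the same path as tuples. As a hedged, standalone sketch of that dispatch pattern (plain pybind11, not MindSpore code): pybind11's `py::tuple` has a converting constructor backed by `PySequence_Tuple`, so once a list enters the branch, the existing tuple cast keeps working.

```cpp
#include <pybind11/embed.h>
#include <pybind11/eval.h>
#include <iostream>

namespace py = pybind11;

// With the old check, only tuples took the sequence path; lists fell through
// to the generic branch. The widened check routes both through the same code.
void Dispatch(const py::object &obj) {
  if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    // py::tuple's converting constructor (PySequence_Tuple) normalizes a
    // list into a tuple, mirroring the obj.cast<py::tuple>() in the patch.
    auto tuple = py::tuple(obj);
    std::cout << "sequence of size " << tuple.size() << '\n';
  } else {
    std::cout << "non-sequence input\n";
  }
}

int main() {
  py::scoped_interpreter guard{};
  Dispatch(py::make_tuple(1, 2));
  Dispatch(py::eval("[1, 2, 3]"));
  Dispatch(py::int_(7));
}
```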
|
|
|
@@ -1106,6 +1106,23 @@ void PynativeExecutor::CleanPreMemoryInValueNode(const std::string &cell_id) {
     top_cell_id_ = cell_id;
     return;
   }
+  if (dynamic_cell_) {
+    std::set<std::string> forward_op_tensor_id;
+    for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) {
+      const auto &tensor_id_list = elem.second;
+      for (const auto &tensor_id : tensor_id_list) {
+        forward_op_tensor_id.emplace(tensor_id);
+      }
+    }
+    for (auto &tensor : all_value_node_tensors_) {
+      if (tensor->device_address() != nullptr &&
+          forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) {
+        tensor->device_address()->ClearDeviceMemory();
+        tensor->set_device_address(nullptr);
+      }
+    }
+    all_value_node_tensors_.clear();
+  }
   const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_];
   for (const auto &elem : tensor_id_with_tensor) {
     const auto &tensors_in_value_node = elem.second;
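
For dynamic cells, the block added above releases the device memory of cached value-node tensors that were produced by forward ops of the previous run (their ids appear in `cell_op_index_with_tensor_id_`), while keeping the host-side tensor objects alive; the snapshot itself is then dropped. A minimal sketch of that clear-by-id pattern, with hypothetical `Tensor`/`DeviceAddress` stand-ins rather than the MindSpore classes:

```cpp
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

struct DeviceAddress {
  void ClearDeviceMemory() { std::cout << "freed device memory\n"; }
};

struct Tensor {
  std::string id;
  std::shared_ptr<DeviceAddress> device_address;
};

int main() {
  // Snapshot of tensors referenced by value nodes (hypothetical data).
  std::vector<std::shared_ptr<Tensor>> all_value_node_tensors = {
      std::make_shared<Tensor>(Tensor{"t1", std::make_shared<DeviceAddress>()}),
      std::make_shared<Tensor>(Tensor{"t2", std::make_shared<DeviceAddress>()})};
  // Ids recorded while executing forward ops; only these get cleared.
  std::set<std::string> forward_op_tensor_id = {"t1"};

  for (auto &tensor : all_value_node_tensors) {
    if (tensor->device_address != nullptr &&
        forward_op_tensor_id.count(tensor->id) != 0) {
      tensor->device_address->ClearDeviceMemory();  // free device side only
      tensor->device_address = nullptr;             // host tensor stays valid
    }
  }
  all_value_node_tensors.clear();  // drop the snapshot, as in the patch
}
```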
|
|
|
@@ -2117,6 +2134,37 @@ std::string PynativeExecutor::GetGradCellId(bool has_sens, const py::object &cel
   return cell_id;
 }
 
+void PynativeExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) {
+  std::unordered_set<tensor::TensorPtr> all_value_node_tensors;
+  auto trace_function = [&all_value_node_tensors](const AnfNodePtr &anf_node) {
+    auto value = GetValueNode(anf_node);
+    if (value) {
+      if (value->isa<tensor::Tensor>()) {
+        auto tensor = value->cast<tensor::TensorPtr>();
+        MS_EXCEPTION_IF_NULL(tensor);
+        if (tensor->device_address()) {
+          all_value_node_tensors.emplace(tensor);
+        }
+      } else if (value->isa<ValueTuple>()) {
+        auto tuple = value->cast<ValueTuplePtr>();
+        MS_EXCEPTION_IF_NULL(tuple);
+        for (size_t i = 0; i < tuple->size(); i++) {
+          if ((*tuple)[i]->isa<tensor::Tensor>()) {
+            auto tensor = (*tuple)[i]->cast<tensor::TensorPtr>();
+            MS_EXCEPTION_IF_NULL(tensor);
+            if (tensor->device_address()) {
+              all_value_node_tensors.emplace(tensor);
+            }
+          }
+        }
+      }
+    }
+    return FOLLOW;
+  };
+  (void)TopoSort(graph->get_return(), SuccDeeperSimple, trace_function);
+  all_value_node_tensors_ = all_value_node_tensors;
+}
+
 void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                     const py::args &args) {
   auto size = args.size();
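
The new `SaveAllValueNodeTensors` walks the graph from its return node and snapshots every tensor held by a value node, directly or inside a `ValueTuple`, that already owns device memory; returning `FOLLOW` tells the traversal to keep descending through inputs. A hedged sketch of that visitor-over-traversal pattern, using stand-in `Node`/`FOLLOW` types rather than MindSpore's `AnfNode`/`TopoSort`/`SuccDeeperSimple`:

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_set>
#include <vector>

enum FollowResult { FOLLOW, NOFOLLOW };

struct Node {
  int payload = 0;           // stand-in for a value node's tensor
  bool has_payload = false;  // stand-in for "tensor with device memory"
  std::vector<std::shared_ptr<Node>> inputs;
};

using NodePtr = std::shared_ptr<Node>;
using Visitor = std::function<FollowResult(const NodePtr &)>;

// Depth-first walk from the return node; the visitor decides whether to keep
// descending, mirroring TopoSort(graph->get_return(), ..., trace_function).
void Traverse(const NodePtr &node, const Visitor &fn, std::unordered_set<Node *> *seen) {
  if (node == nullptr || !seen->insert(node.get()).second) {
    return;  // already visited
  }
  if (fn(node) == FOLLOW) {
    for (const auto &input : node->inputs) {
      Traverse(input, fn, seen);
    }
  }
}

int main() {
  auto leaf = std::make_shared<Node>(Node{42, true, {}});
  auto ret = std::make_shared<Node>(Node{0, false, {leaf}});

  std::unordered_set<int> collected;
  auto trace_function = [&collected](const NodePtr &node) {
    if (node->has_payload) {
      collected.insert(node->payload);  // snapshot the tensor-like payload
    }
    return FOLLOW;  // keep following inputs, as in the patch
  };

  std::unordered_set<Node *> seen;
  Traverse(ret, trace_function, &seen);
  std::cout << "collected " << collected.size() << " payload(s)\n";
}
```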
|
|
|
@@ -2158,6 +2206,9 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
   resource->results()[pipeline::kBackend] = compile::CreateBackend();
 
   MS_LOG(INFO) << "Start opt";
+  if (dynamic_cell_) {
+    SaveAllValueNodeTensors(resource->func_graph());
+  }
   PynativeOptimizeAction(resource);
   SaveTensorsInValueNode(resource);
   TaskEmitAction(resource);
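
The call site guards the snapshot behind `dynamic_cell_` and takes it before `PynativeOptimizeAction` runs. Presumably the ordering matters: the set has to reflect the value nodes of the forward graph as captured, before optimization passes can fold or replace them, so that `CleanPreMemoryInValueNode` can later release exactly those tensors' device memory on the next run of the top cell.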
|
|
|
|