diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 3998b6c9fe..9cb0d301e1 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -141,11 +141,6 @@ void RunOpsInGraphTask::Run() { session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_); } -void CleanUselessTensorsTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - session_->CleanUselessTensorsImpl(useless_tensors_); -} - void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); } void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); } @@ -392,15 +387,6 @@ void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, *outputs = task->outputs_; } -void Executor::CleanUselessTensors(const SessionPtr &session, - const std::shared_ptr> &useless_tensors) { - MS_EXCEPTION_IF_NULL(useless_tensors); - auto task = std::make_shared(); - task->session_ = session; - task->useless_tensors_ = useless_tensors; - SyncRunTask(task); -} - bool Executor::CreateCommGroup(const std::string &group_name, std::vector ranks) { auto task = std::make_shared(); task->group_name_ = group_name; diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index d0085f5263..af501433d3 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -46,8 +46,7 @@ enum TaskType { kRunOp, kCreateCommGroup, kDestroyCommGroup, - kRunOpsInGraph, - kCleanUselessTensors + kRunOpsInGraph }; class Task { @@ -110,14 +109,6 @@ class RunOpsInGraphTask : public Task { GraphId graph_id_{0}; }; -class CleanUselessTensorsTask : public Task { - public: - CleanUselessTensorsTask() { type_ = kCleanUselessTensors; } - ~CleanUselessTensorsTask() override = default; - void Run() override; - std::shared_ptr> useless_tensors_{nullptr}; -}; - class RunOpTask : 
public Task { public: RunOpTask() { type_ = kRunOp; } @@ -175,8 +166,6 @@ class Executor { const std::vector &tensors_mask); void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); - void CleanUselessTensors(const SessionPtr &session, - const std::shared_ptr> &useless_tensors); bool CreateCommGroup(const std::string &group_name, std::vector ranks); bool DestroyCommGroup(const std::string &group_name); void OnEvent(const ExecutorEvent &event); diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index df0b4e7ecf..d8d87fe615 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1657,11 +1657,6 @@ void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vectorRunOpsInGraph(shared_from_this(), graph_id, inputs, outputs); } -void SessionBasic::CleanUselessTensors(const std::shared_ptr> &useless_tensors) { - MS_EXCEPTION_IF_NULL(executor_); - executor_->CleanUselessTensors(shared_from_this(), useless_tensors); -} - void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { MS_EXCEPTION_IF_NULL(executor_); executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs); @@ -1710,22 +1705,6 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull &ro root_graph->UpdateGraphDynamicAttr(); } -void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr> &useless_tensors) { - auto ms_context = MsContext::GetInstance(); - std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - if (device_target == "CPU") { - return; - } - for (const auto &tensor : *useless_tensors) { - MS_EXCEPTION_IF_NULL(tensor); - const auto &shape = tensor->shape(); - if (!shape.empty()) { - // The address of scalar value node does not need to be deleted - tensor->set_device_address(nullptr); - } - } -} - bool 
SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) { auto kernel_graph = graphs_[graph_id]; MS_EXCEPTION_IF_NULL(kernel_graph); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index a6360dfd0f..a042e00d91 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -82,7 +82,6 @@ class SessionBasic : public std::enable_shared_from_this { void RunOp(OpRunInfo *, const GraphInfo &, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask); void RunOpsInGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); - void CleanUselessTensors(const std::shared_ptr> &useless_tensors); virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); @@ -142,7 +141,6 @@ class SessionBasic : public std::enable_shared_from_this { friend class RunGraphTask; friend class RunOpTask; friend class RunOpsInGraphTask; - friend class CleanUselessTensorsTask; virtual bool IsSupportSummary() { return true; } virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, VectorRef *outputs, @@ -164,7 +162,6 @@ class SessionBasic : public std::enable_shared_from_this { const std::vector &tensors_mask) {} virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) {} - void CleanUselessTensorsImpl(const std::shared_ptr> &useless_tensors); void RunInfer(NotNull func_graph, const std::vector &inputs); virtual void SetSummaryNodes(KernelGraph *graph); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 686097fb33..811a53e4e9 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -253,6 +253,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { k_app = k_graph_->NewCNode(inputs); } 
ReplaceEquivdout(k_app, cnode_morph); + cnode_morph->clear_inputs_value(); cnode_morph->set_forward(nullptr, ""); for (size_t i = 0; i < param_adjoints.size(); ++i) { param_adjoints[i]->RegisterKUser(k_app, i); @@ -387,7 +388,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor MS_EXCEPTION_IF_NULL(out_node); out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward)); // clear resource - cnode_morph->clear_inputs_value(); fg->ClearAllManagerInfo(); func_graph->ClearAllManagerInfo(); } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 249678e30c..3873dc68bf 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -92,6 +92,7 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; const char NAMED_PRIMITIVE_LEN[] = "len"; const char NAMED_PRIMITIVE_BODY[] = "body"; const char NAMED_PRIMITIVE_ASSIGN[] = "Assign"; +const char NAMED_PRIMITIVE_AUGASSIGN[] = "AugAssign"; const char NAMED_PRIMITIVE_FOR[] = "For"; const char NAMED_PRIMITIVE_IF[] = "If"; const char NAMED_PRIMITIVE_WHILE[] = "While"; @@ -105,6 +106,7 @@ const char NAMED_PRIMITIVE_ATTRIBUTE[] = "Attribute"; const char NAMED_PRIMITIVE_COMPARE[] = "Compare"; const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant"; const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators"; +const char NAMED_PRIMITIVE_TARGET[] = "target"; const char NAMED_PRIMITIVE_SLICE[] = "slice"; const char NAMED_PRIMITIVE_NAME[] = "Name"; const char NAMED_PRIMITIVE_NUM[] = "Num"; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index ab24aeafd8..ea461d66f7 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -622,7 +622,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const 
py::args &args) { op_exec_info->op_name = op_name; if (grad_flag()) { int64_t graph_id = graph_id_; - auto resource = GetResource(); + auto resource = GetResource(top_cell_id_); if (resource != nullptr) { MS_LOG(DEBUG) << "Get resource ptr " << resource.get(); auto it = resource->results().find(pipeline::kPynativeGraphId); @@ -1007,21 +1007,21 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex MS_EXCEPTION_IF_NULL(output_value); std::vector output_tensors; TensorValueToTensor(output_value, &output_tensors); - if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) { + if (cell_op_index_with_tensor_id_[top_cell_id_].find(op_index) == cell_op_index_with_tensor_id_[top_cell_id_].end()) { // first step std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) { - op_index_with_tensor_id_[op_index].emplace_back(tensor->id()); + cell_op_index_with_tensor_id_[top_cell_id_][op_index].emplace_back(tensor->id()); }); return; } auto ms_context = MsContext::GetInstance(); auto target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - const auto &tensor_id_list = op_index_with_tensor_id_[op_index]; + const auto &tensor_id_list = cell_op_index_with_tensor_id_[top_cell_id_][op_index]; for (size_t i = 0; i < tensor_id_list.size(); ++i) { auto tensor_id = tensor_id_list[i]; - if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) { + if (cell_tensor_id_with_tensor_[top_cell_id_].find(tensor_id) != cell_tensor_id_with_tensor_[top_cell_id_].end()) { auto &new_tensor = output_tensors[i]; - auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id]; + auto &tensors_in_value_node = cell_tensor_id_with_tensor_[top_cell_id_][tensor_id]; std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id " << tensor->id() << ", device 
address " << tensor->device_address().get() @@ -1050,7 +1050,15 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { MS_EXCEPTION_IF_NULL(resource); - tensor_id_with_tensor_.clear(); + std::set forward_op_tensor_id; + for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) { + const auto &tensor_id_list = elem.second; + for (const auto &tensor_id : tensor_id_list) { + forward_op_tensor_id.emplace(tensor_id); + } + } + + cell_tensor_id_with_tensor_[top_cell_id_].clear(); const auto &func_graph = resource->func_graph(); const auto &value_node_list = func_graph->value_nodes(); for (const auto &elem : value_node_list) { @@ -1059,8 +1067,9 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { std::vector tensors; TensorValueToTensor(value_node->value(), &tensors); for (const auto &tensor : tensors) { - if (tensor->device_address() != nullptr) { - tensor_id_with_tensor_[tensor->id()].emplace_back(tensor); + if (tensor->device_address() != nullptr && + forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) { + cell_tensor_id_with_tensor_[top_cell_id_][tensor->id()].emplace_back(tensor); MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id() << ", device address " << tensor->device_address().get(); } @@ -1068,16 +1077,22 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { } } -void PynativeExecutor::CleanTensorsInValueNode() { - // Only need clean in ms backend policy and session should not be nullptr in ms backend. 
- if (session == nullptr) { +void PynativeExecutor::CleanPreMemoryInValueNode(const std::string &cell_id) { + auto ms_context = MsContext::GetInstance(); + std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); + if (device_target == "CPU") { + top_cell_id_ = cell_id; return; } - auto useless_tensors = std::make_shared>(); - for (const auto &id_tensor_pair : tensor_id_with_tensor_) { - std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors)); + const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_]; + for (const auto &elem : tensor_id_with_tensor) { + const auto &tensors_in_value_node = elem.second; + for (const auto &tensor : tensors_in_value_node) { + MS_EXCEPTION_IF_NULL(tensor); + tensor->set_device_address(nullptr); + } } - session->CleanUselessTensors(useless_tensors); + top_cell_id_ = cell_id; } AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { @@ -1468,6 +1483,12 @@ void PynativeExecutor::SubNestedGradOrder() { } bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { + auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfo &value) { + return value.cell_id == cell_id && value.is_dynamic_cell; + }); + if (it != top_cell_list_.end()) { + return false; + } return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) { return value.cell_id == cell_id && (!is_grad || value.is_grad); }); @@ -1590,6 +1611,22 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr self.b and changed self.a or self.b + if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) { + auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); + std::string left_variable; + if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) { + 
left_variable = py::cast(left_value.attr("id")) + py::cast(left_node.attr("attr")); + } + auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE); + std::string right_variable; + if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) { + right_variable = + py::cast(right_value.attr("id")) + py::cast(comparators_node[0].attr("attr")); + } + return ParseBodyContext(ast, node, {left_variable, right_variable}); + } + // if a[0] if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) { py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR); @@ -1629,6 +1666,34 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node, + const std::vector &compare_prim) { + MS_LOG(DEBUG) << "Parse augassign expr"; + bool ret = false; + if (compare_prim.empty()) { + return ret; + } + py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET); + if (py::isinstance(target_node)) { + MS_LOG(DEBUG) << "Parse target node is none!"; + return ret; + } + py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE); + if (py::isinstance(value_node)) { + MS_LOG(DEBUG) << "Parse value node is none!"; + return ret; + } + std::string assign_prim; + if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) { + assign_prim = py::cast(value_node.attr("id")) + py::cast(target_node.attr("attr")); + } + auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim); + if (iter != compare_prim.end()) { + ret = true; + } + return ret; +} + bool PynativeExecutor::ParseForExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse for expr"; py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); @@ -1649,7 
+1714,8 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr & return false; } -bool PynativeExecutor::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node) { +bool PynativeExecutor::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, + const std::vector &compare_prim) { MS_EXCEPTION_IF_NULL(ast); py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); if (py::isinstance(func_obj)) { @@ -1665,6 +1731,8 @@ bool PynativeExecutor::ParseBodyContext(const std::shared_ptr & const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT); if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) { ret = ParseAssignExprNode(ast, node); + } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) { + ret = ParseAugAssignExprNode(ast, node, compare_prim); } else if (node_name == parse::NAMED_PRIMITIVE_FOR) { ret = ParseForExprNode(ast, node); } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) { @@ -1719,17 +1787,18 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg MS_LOG(EXCEPTION) << "Top cell list is empty"; } if (IsTopGraph(cell_id)) { + // Clear previous step resource op_index_map_.clear(); + CleanPreMemoryInValueNode(cell_id); } MS_LOG(INFO) << "NewGraph already compiled"; return; } // init resource for constructing forward graph and grad graph - auto g = std::make_shared(); - curr_g_ = g; + curr_g_ = std::make_shared(); ClearResidualRes(cell_id); if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { - MakeNewTopGraph(cell_id, args, g); + MakeNewTopGraph(cell_id, args, curr_g_); } PushCurrentGraphToStack(); if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) { @@ -1738,7 +1807,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg } for (size_t i = 0; i < args.size(); ++i) { auto param = args[i]; - auto new_param = g->add_parameter(); + auto new_param = 
curr_g_->add_parameter(); std::string param_id = GetId(param); SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true); SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); @@ -1747,6 +1816,13 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg // check whether the construct of cell will be changed if (!dynamic_cell_) { dynamic_cell_ = IsDynamicCell(cell); + if (dynamic_cell_) { + auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), + [&](const TopCellInfo &value) { return value.cell_id == top_cell_id_; }); + if (it != top_cell_list_.end()) { + it->is_dynamic_cell = dynamic_cell_; + } + } MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_; } } @@ -1760,16 +1836,17 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar } } } - // Clear runop pre + // Clear resource in old top cell auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); if (it != top_cell_list_.end()) { top_cell_list_.erase(it); } - dynamic_cell_ = false; op_index_map_.clear(); - op_index_with_tensor_id_.clear(); + CleanPreMemoryInValueNode(cell_id); + // Init resource for new top cell + dynamic_cell_ = false; auto df_builder = std::make_shared(); GraphInfo graph_info = GraphInfo(cell_id); graph_info_map_.emplace(df_builder, graph_info); @@ -2359,7 +2436,6 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, MS_LOG(DEBUG) << "Eval run " << backend; set_grad_runing(true); BaseRef value = (*run)(arg_list); - CleanTensorsInValueNode(); set_grad_runing(false); MS_LOG(DEBUG) << "Eval run end " << value.ToString(); auto out = BaseRefToPyData(value); @@ -2506,8 +2582,8 @@ void PynativeExecutor::ClearRes() { cell_graph_list_.clear(); top_cell_list_.clear(); op_index_map_.clear(); - op_index_with_tensor_id_.clear(); - tensor_id_with_tensor_.clear(); + 
cell_op_index_with_tensor_id_.clear(); + cell_tensor_id_with_tensor_.clear(); cell_dynamic_map_.clear(); prim_abs_list_.clear(); std::stack().swap(graph_stack_); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 12d854429a..16dc576c15 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -52,6 +52,8 @@ struct PrimAbsInfo { using AbstractListMap = std::unordered_map; +using OpIndexWithTensorId = std::unordered_map>; +using TensorIdWithTensor = std::unordered_map>; py::tuple RunOp(const py::args &args); @@ -87,6 +89,7 @@ struct TopCellInfo { FuncGraphPtr df_builder; FuncGraphPtr bg; // Backward graph std::string cell_id; + bool is_dynamic_cell{false}; TopCellInfo() = default; TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid) : resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {} @@ -154,9 +157,12 @@ class PynativeExecutor : public std::enable_shared_from_this { bool IsDynamicCell(const py::object &cell); std::string GetCellInfo(const py::object &cell); void ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node); - bool ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node); + bool ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, + const std::vector &compare_prim = {}); bool ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node); bool ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node); + bool ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node, + const std::vector &compare_prim = {}); bool ParseForExprNode(const std::shared_ptr &ast, const py::object &node); std::string ParseNodeName(const std::shared_ptr &ast, const py::object &node, parse::AstMainType type); @@ -190,7 +196,7 @@ class PynativeExecutor : 
public std::enable_shared_from_this { // Update the abstract and device address info of value node and tensors in bprop graph void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); void SaveTensorsInValueNode(const ResourcePtr &resource); - void CleanTensorsInValueNode(); + void CleanPreMemoryInValueNode(const std::string &cell_id); // Construct grad graph void PushCurrentGraphToStack(); @@ -261,6 +267,7 @@ class PynativeExecutor : public std::enable_shared_from_this { static std::mutex instance_lock_; static int64_t graph_id_; size_t grad_order_{0}; + std::string top_cell_id_; bool grad_flag_{false}; bool dynamic_cell_{false}; bool grad_is_running_{false}; @@ -285,8 +292,8 @@ class PynativeExecutor : public std::enable_shared_from_this { // Used for runop and replace forward result of grad graph std::unordered_map op_index_map_; std::unordered_map obj_to_forward_id_; - std::unordered_map> op_index_with_tensor_id_; - std::unordered_map> tensor_id_with_tensor_; + std::unordered_map cell_op_index_with_tensor_id_; + std::unordered_map cell_tensor_id_with_tensor_; std::unordered_map node_abs_map_; std::unordered_map prim_abs_list_; }; diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 8107322b93..1491c74e10 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -553,7 +553,7 @@ std::string Tensor::ToStringInternal(int limit_size) const { std::ostringstream buf; auto dtype = Dtype(); MS_EXCEPTION_IF_NULL(dtype); - buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value="; + buf << "Tensor(id=" << id_ << ", shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value="; if (limit_size <= 0 || DataSize() < limit_size) { // Only print data for small tensor. buf << ((data().ndim() > 1) ? 
'\n' : ' ') << data().ToString(data_type_, shape_, false); diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 6c85fcdf6b..b51391bbe0 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -361,7 +361,6 @@ class Cell(Cell_): _pynative_exec.end_graph(self, output, *inputs, **kwargs) for i, cell in enumerate(self.cells()): cell.set_grad(origin_grad[i]) - self._already_run = True return output def _add_attr(self, name, value): diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 37c5f4a3c8..2fac0f0b4e 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -182,6 +182,9 @@ class GradOperation(GradOperation_): sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input. If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False. + If sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred through + the positional parameter or key-value pair parameter. If the value is transferred through the key-value pair + parameter, the key must be 'sens'. Returns: The higher-order function which takes a function as argument and returns gradient function for it. @@ -311,16 +314,23 @@ class GradOperation(GradOperation_): def _pynative_forward_run(self, args, kwargs, fn): """ Pynative forward run to build grad graph. 
""" + new_kwargs = {} if self.sens_param: - args = args[:-1] + if not 'sens' in kwargs.keys(): + args = args[:-1] + new_kwargs = kwargs + else: + for key, value in kwargs.items(): + if key != 'sens': + new_kwargs[key] = value for arg in args: if not isinstance(arg, Tensor): raise TypeError("grad inputs should be tensor in pynative mode") if isinstance(fn, FunctionType): _pynative_exec.set_grad_flag(True) - _pynative_exec.new_graph(fn, *args, **kwargs) - output = fn(*args, **kwargs) - _pynative_exec.end_graph(fn, output, *args, **kwargs) + _pynative_exec.new_graph(fn, *args, **new_kwargs) + output = fn(*args, **new_kwargs) + _pynative_exec.end_graph(fn, output, *args, **new_kwargs) else: if fn.already_run and not fn.requires_grad: raise ValueError("obj must set_grad.") @@ -328,7 +338,7 @@ class GradOperation(GradOperation_): self.need_forward = True if self.need_forward: fn.set_grad() - fn(*args, **kwargs) + fn(*args, **new_kwargs) fn.already_run = False def __call__(self, fn, weights=None): diff --git a/tests/st/pynative/test_pynative_resnet50_ascend.py b/tests/st/pynative/test_pynative_resnet50_ascend.py index 652424a225..0aa9ce3aa8 100644 --- a/tests/st/pynative/test_pynative_resnet50_ascend.py +++ b/tests/st/pynative/test_pynative_resnet50_ascend.py @@ -404,10 +404,10 @@ def test_pynative_resnet50(): step = step + 1 if step > max_step: break - start_time = time.time() input_data = element["image"] input_label = element["label"] loss_output = net_with_criterion(input_data, input_label) + start_time = time.time() grads = train_network(input_data, input_label) optimizer(grads) end_time = time.time() diff --git a/tests/st/pynative/test_pynative_resnet50_gpu.py b/tests/st/pynative/test_pynative_resnet50_gpu.py index 9402e51853..064ee31017 100644 --- a/tests/st/pynative/test_pynative_resnet50_gpu.py +++ b/tests/st/pynative/test_pynative_resnet50_gpu.py @@ -403,10 +403,10 @@ def test_pynative_resnet50(): step = step + 1 if step > max_step: break - start_time = 
time.time() input_data = element["image"] input_label = element["label"] loss_output = net_with_criterion(input_data, input_label) + start_time = time.time() grads = train_network(input_data, input_label) optimizer(grads) end_time = time.time()