@@ -25,6 +25,7 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/manager.h"
 #include "pipeline/jit/resource.h"
+#include "pipeline/pynative/pynative_execute.h"
 #include "frontend/optimizer/ad/adjoint.h"
 #include "frontend/operator/ops.h"
 #include "utils/symbolic.h"
@@ -218,7 +219,8 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
   auto k_app = k_graph_->NewCNode(inputs);
   TraceManager::EndTrace();
-  ReplaceEquivdout(k_app, cnode_morph->forward());
+  ReplaceEquivdout(k_app, cnode_morph);
+  cnode_morph->set_forward(nullptr, "");
   for (size_t i = 0; i < param_adjoints.size(); ++i) {
     param_adjoints[i]->RegisterKUser(k_app, i);
   }
@@ -240,7 +242,9 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }
 
-void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward) {
+void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
+  auto forward = cnode_morph->forward().first;
+  auto forward_id = cnode_morph->forward().second;
   if (forward == nullptr) {
     return;
   }
@@ -265,10 +269,44 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward)
   auto equivdout = cnode_input->cast<CNodePtr>();
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
   auto manager = Manage({fg, func_graph}, false);
+  auto ref_size = manager->node_users()[equivdout].size();
+  auto forward_value = forward;
+  if (!forward_id.empty() && ref_size > 1) {
+    auto inst = pynative::PynativeExecutor::GetInstance();
+    inst->SaveOpForwardValue(forward_id, forward_value);
+  }
+  if (ref_size < 2) {
+    auto tensor = forward->cast<tensor::TensorPtr>();
+    if (tensor != nullptr) {
+      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape());
+      forward_value = new_tensor;
+    }
+  }
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
-  auto value_node = NewValueNode(forward);
+  auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);
   manager->Replace(equivdout, value_node);
+  auto paras = fg->parameters();
+  auto inputs_value = cnode_morph->inputs_value();
+  if (inputs_value.size() == 0) {
+    return;
+  }
+  if (inputs_value.size() != paras.size()) {
+    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " is not equal to inputs size:" << inputs_value.size();
+  }
+  for (size_t i = 0; i < paras.size(); i++) {
+    auto para_ref_size = manager->node_users()[paras[i]].size();
+    auto input_value = inputs_value[i];
+    if (para_ref_size > 0 && input_value.first != nullptr) {
+      MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
+      auto inst = pynative::PynativeExecutor::GetInstance();
+      inst->SaveOpForwardValue(input_value.second, input_value.first);
+      auto input_value_node = NewValueNode(input_value.first);
+      manager->Replace(paras[i], input_value_node);
+    }
+  }
+  cnode_morph->clear_inputs_value();
+  return;
 }
 
 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
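The rewritten body makes two reference-count-driven decisions: a forward value still used by more than one node is parked in the PyNative executor's cache under its forward id, while a value with a single user is swapped for a dtype/shape-only tensor so the original data buffer can be released. The parameter loop then applies the same caching to the recorded input values before clearing the per-node record. A compilable sketch of the embedding decision, with simplified stand-ins for the real tensor and cache types (`ValueToEmbed` and the field layout are illustrative, not MindSpore API):

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for mindspore::tensor::Tensor and the executor cache.
struct Tensor {
  int data_type;
  std::vector<int> shape;
  std::vector<float> data;  // the payload we would like to release early
};
using TensorPtr = std::shared_ptr<Tensor>;

std::unordered_map<std::string, TensorPtr> op_forward_map;  // id -> cached value

// Decide which value to embed for a forward output with `ref_size` users.
TensorPtr ValueToEmbed(const TensorPtr &forward, const std::string &forward_id, std::size_t ref_size) {
  if (!forward_id.empty() && ref_size > 1) {
    // Shared output: keep the real value reachable through the id-keyed cache.
    op_forward_map.emplace(forward_id, forward);
    return forward;
  }
  if (ref_size < 2) {
    // Single user: a dtype/shape-only placeholder keeps type information
    // alive while letting the original data buffer be freed.
    return std::make_shared<Tensor>(Tensor{forward->data_type, forward->shape, {}});
  }
  return forward;
}

int main() {
  auto t = std::make_shared<Tensor>(Tensor{1, {2, 3}, {1, 2, 3, 4, 5, 6}});
  auto shared = ValueToEmbed(t, "0Add0", 2);    // cached and embedded as-is
  auto stripped = ValueToEmbed(t, "0Mul0", 1);  // placeholder without payload
  return shared->data.size() == 6 && stripped->data.empty() ? 0 : 1;
}
```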
@@ -95,7 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
   // Update k hole with adjoint_definition, only applied in recursive case.
   void UpdateAdjoint(const AdjointPtr &adjoint_definition);
   void CallDoutHoleOnTape();
-  void ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward);
+  void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
 
   std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
   // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
@@ -724,18 +724,14 @@ void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::ob
   set_pyobj(curr_g_, obj_id);
 }
 
-void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
-  auto id = GetOpId(op_exec_info);
-  int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
-  auto op = std::to_string(graph_id) + id;
-  op.append(std::to_string(op_id_map_[id]));
-  auto iter = op_forward_map_.find(op);
+void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
+  auto iter = op_forward_map_.find(id);
   if (iter != op_forward_map_.end()) {
     return;
   }
-  op_forward_map_[op] = value;
-  ++op_id_map_[id];
-  MS_LOG(DEBUG) << "Save: " << op_exec_info->op_name << "(" << op << "), " << value;
+  op_forward_map_[id] = value;
+  MS_LOG(DEBUG) << "Save op forward value: "
+                << "(" << id << "), " << value;
 }
 
 void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
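With the id now composed by the caller, the counter bookkeeping moves to SaveAllResult and this function reduces to an insert-once cache keyed by the caller-supplied id. A minimal sketch of the new contract, using stand-in types rather than the real executor:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

using ValuePtr = std::shared_ptr<int>;  // stand-in for mindspore::ValuePtr

std::unordered_map<std::string, ValuePtr> op_forward_map;

// Insert-once semantics: the first value recorded for an id wins, and the
// callee no longer needs to know how the id was composed.
void SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
  if (op_forward_map.find(id) != op_forward_map.end()) {
    return;  // already cached for this id
  }
  op_forward_map[id] = value;
}

int main() {
  SaveOpForwardValue("0Add0", std::make_shared<int>(1));
  SaveOpForwardValue("0Add0", std::make_shared<int>(2));  // ignored: id already present
  std::cout << *op_forward_map["0Add0"] << "\n";          // prints 1
}
```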
@@ -748,9 +744,25 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
   }
   auto value = PyAttrValue(out_real);
   if (cnode != nullptr) {
-    cnode->set_forward(value);
+    size_t size = op_exec_info->op_inputs.size();
+    for (size_t i = 0; i < size; i++) {
+      auto obj = op_exec_info->op_inputs[i];
+      auto obj_id = GetId(obj);
+      if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
+        cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
+      } else {
+        cnode->add_input_value(nullptr, "");
+      }
+    }
+    std::string id = GetOpId(op_exec_info);
+    int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
+    auto op_id = std::to_string(graph_id) + id;
+    op_id.append(std::to_string(op_id_map_[id]));
+    cnode->set_forward(value, op_id);
+    ++op_id_map_[id];
+    auto out_id = GetId(out_real);
+    obj_to_forward_id_[out_id] = op_id;
   }
-  SaveOpForwardValue(op_exec_info, value);
 }
 
 AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
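SaveAllResult now owns the id bookkeeping end to end: it records each input's value together with the forward id it was previously saved under, composes this op's id from the graph id, the op identity, and a running per-op counter, and finally maps the output object to that id so downstream ops can tag their own inputs. A self-contained sketch of the id plumbing (types simplified; GetId, GetOpId, and the resource lookup are elided):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Illustrative sketch of the id bookkeeping; names mirror the patch but the
// types are simplified (no real graphs or Python objects involved).
std::unordered_map<std::string, size_t> op_id_map;               // per-op counter
std::unordered_map<std::string, std::string> obj_to_forward_id;  // output obj -> op id

// Compose a unique forward id: graph id + op identity + running counter.
std::string MakeForwardId(int graph_id, const std::string &op_name) {
  auto op_id = std::to_string(graph_id) + op_name + std::to_string(op_id_map[op_name]);
  ++op_id_map[op_name];
  return op_id;
}

int main() {
  // First Add in graph 0 produces output object "out1".
  auto id0 = MakeForwardId(0, "Add");
  obj_to_forward_id["out1"] = id0;

  // A later op consuming "out1" can recover the id its input was saved under.
  auto it = obj_to_forward_id.find("out1");
  std::cout << id0 << " " << (it != obj_to_forward_id.end() ? it->second : "<none>") << "\n";
  // prints: 0Add0 0Add0
}
```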
@@ -775,7 +787,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
     node_abs_map_[id] = node->abstract();
   }
   MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
-  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj));
+  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj), "");
   return node;
 }
 
@@ -1131,6 +1143,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
     }
   }
   auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
+  graph_info_map_.erase(curr_g_);
   if (curr_g_ != top_g_) {
     Popp();
     for (size_t i = 0; i < args.size(); i++) {
@@ -1300,6 +1313,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
+  obj_to_forward_id_.clear();
   std::stack<FuncGraphPtr>().swap(graph_p_);
   ConfigManager::GetInstance().ResetIterNum();
 }
@@ -108,7 +108,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                  abstract::AbstractBasePtrList *args_spec_list);
   void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
   ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
-  void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
+  void SaveOpForwardValue(const std::string &id, const ValuePtr &value);
   void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
@@ -138,6 +138,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
+  std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::stack<FuncGraphPtr> graph_p_;
   FuncGraphPtr top_g_;
@@ -31,7 +31,7 @@
 namespace mindspore {
 // namespace to support intermediate representation definition
 CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
-    : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
+    : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false), output_value_(std::make_pair(nullptr, "")) {}
 
 // Check if CNode is an apply with the specific Primitive.
 bool CNode::IsApply(const PrimitivePtr &value) const {
@@ -232,8 +232,15 @@ class CNode : public AnfNode {
   void set_input(size_t i, const AnfNodePtr &input);
   void set_inputs(const std::vector<AnfNodePtr> &inputs) { inputs_ = inputs; }
-  void set_forward(const ValuePtr &forward) { forward_ = forward; }
-  const ValuePtr &forward() const { return forward_; }
+  void add_input_value(const ValuePtr &input_value, const std::string &id) {
+    inputs_value_.push_back(std::make_pair(input_value, id));
+  }
+  void clear_inputs_value() { inputs_value_.clear(); }
+  void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
+  const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
+
+  void set_forward(const ValuePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
+  const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }
 
   bool stop_gradient() const { return stop_gradient_; }
   void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
@@ -253,7 +260,10 @@ class CNode : public AnfNode {
   VarPtr func_graph_as_var_;
   bool stop_gradient_;
   bool in_forward_flag_ = false;
-  ValuePtr forward_ = nullptr;
+  // inputs_value_ stores the cnode input values and ids in pynative mode
+  // output_value_ stores the cnode output value and id in pynative mode
+  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
+  std::pair<ValuePtr, std::string> output_value_;
 };
 
 // ANode represents the atomic node. It's derived Parameter and ValueNode.
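On the IR side, the single forward_ slot becomes a (value, id) pair, and a parallel vector records each input's value and id, which is what ReplaceEquivdout later consumes. A stripped-down illustration of the new accessors; MiniCNode is a stand-in for the real mindspore::CNode, not its actual definition:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using ValuePtr = std::shared_ptr<int>;  // stand-in for mindspore::ValuePtr

class MiniCNode {
 public:
  void add_input_value(const ValuePtr &v, const std::string &id) { inputs_value_.emplace_back(v, id); }
  void clear_inputs_value() { inputs_value_.clear(); }
  const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
  void set_forward(const ValuePtr &v, const std::string &id) { output_value_ = {v, id}; }
  const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }

 private:
  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;  // per-input (value, id)
  std::pair<ValuePtr, std::string> output_value_{nullptr, ""};  // op output (value, id)
};

int main() {
  MiniCNode node;
  node.add_input_value(std::make_shared<int>(3), "0Mul0");  // input produced by an earlier op
  node.add_input_value(nullptr, "");                        // input with no recorded forward id
  node.set_forward(std::make_shared<int>(6), "0Add0");

  std::cout << node.forward().second << " inputs=" << node.inputs_value().size() << "\n";
  node.set_forward(nullptr, "");  // release once the grad pass has consumed it
  node.clear_inputs_value();
}
```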
@@ -88,7 +88,8 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
   CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
   auto old_node = node->cast<CNodePtr>();
   new_node->set_abstract(old_node->abstract());
-  new_node->set_forward(old_node->forward());
+  new_node->set_forward(old_node->forward().first, old_node->forward().second);
+  new_node->set_inputs_value(old_node->inputs_value());
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_node->set_scope(scope);
   new_node->set_kernel_info(old_node->kernel_info_ptr());
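Finally, CloneCNode has to carry both halves of the forward pair and the recorded input values onto the clone; copying only the value, as the old one-argument set_forward did, would drop the ids the PyNative executor saved cache entries under. A tiny sketch of the invariant, with a hypothetical MiniCNode in place of the real class:

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

using ValuePtr = std::shared_ptr<int>;  // stand-in for mindspore::ValuePtr
using ValueWithId = std::pair<ValuePtr, std::string>;

// Only the two fields this patch adds to CNode; everything else elided.
struct MiniCNode {
  ValueWithId output_value{nullptr, ""};
  std::vector<ValueWithId> inputs_value;
};

MiniCNode CloneCNode(const MiniCNode &old_node) {
  MiniCNode new_node;
  // Copy value *and* id for the output, plus the whole input record;
  // losing the ids would orphan the executor's cached forward values.
  new_node.output_value = old_node.output_value;
  new_node.inputs_value = old_node.inputs_value;
  return new_node;
}

int main() {
  MiniCNode old_node;
  old_node.output_value = {std::make_shared<int>(5), "0Add0"};
  old_node.inputs_value.push_back({std::make_shared<int>(2), "0Mul0"});
  auto copy = CloneCNode(old_node);
  return copy.output_value.second == "0Add0" ? 0 : 1;
}
```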