| @@ -390,18 +390,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i | |||
| // run op | |||
| Execute(graph, false); | |||
| // get output | |||
| if (op_run_info.value != nullptr) { | |||
| std::vector<tensor::TensorPtr> pre_output_tensors; | |||
| TensorValueToTensor(op_run_info.value, &pre_output_tensors); | |||
| for (auto &pre_output : pre_output_tensors) { | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape()); | |||
| tensor->set_device_address(pre_output->device_address()); | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| outputs->emplace_back(tensor); | |||
| } | |||
| } else { | |||
| UpdateOutputs(graph, outputs, input_tensors); | |||
| } | |||
| UpdateOutputs(graph, outputs, input_tensors); | |||
| RunOpMemoryClear(graph.get()); | |||
| MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; | |||
| } | |||
| @@ -335,18 +335,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info | |||
| LoadInputData(kernel_graph, input_tensors); | |||
| Execute(kernel_graph); | |||
| // Fetch outputs | |||
| if (op_run_info.value != nullptr) { | |||
| std::vector<tensor::TensorPtr> pre_output_tensors; | |||
| TensorValueToTensor(op_run_info.value, &pre_output_tensors); | |||
| for (auto &pre_output : pre_output_tensors) { | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape()); | |||
| tensor->set_device_address(pre_output->device_address()); | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| outputs->emplace_back(tensor); | |||
| } | |||
| } else { | |||
| UpdateOutputs(kernel_graph, outputs, input_tensors); | |||
| } | |||
| UpdateOutputs(kernel_graph, outputs, input_tensors); | |||
| RunOpClearMemory(kernel_graph.get()); | |||
| } | |||
| @@ -30,6 +30,8 @@ | |||
| #include "frontend/operator/ops.h" | |||
| #include "utils/symbolic.h" | |||
| #include "utils/ms_context.h" | |||
| #include "pipeline/jit/action.h" | |||
| #include "pipeline/jit/parse/resolve.h" | |||
| namespace mindspore { | |||
| namespace ad { | |||
| @@ -183,6 +185,7 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, | |||
| // Map a morphism. | |||
| AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { | |||
| MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4); | |||
| // MapMorphism All type except CNode should already be mapped by MapObject. | |||
| if (!morph->isa<CNode>()) { | |||
| return nullptr; | |||
| @@ -238,9 +241,54 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { | |||
| // Do sens backpropagation | |||
| BackPropagate(cnode_morph, k_app, node_adjoint); | |||
| MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << "."; | |||
| MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << "."; | |||
| return node_adjoint; | |||
| } | |||
| void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) { | |||
| MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>(); | |||
| if (value->isa<tensor::Tensor>()) { | |||
| auto tnode = value->cast<tensor::TensorPtr>(); | |||
| if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) { | |||
| MS_LOG(DEBUG) << "Set tensor" << tnode->device_address(); | |||
| (*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address()); | |||
| } | |||
| } | |||
| if (value->isa<ValueTuple>()) { | |||
| auto tuple = value->cast<ValueTuplePtr>(); | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString(); | |||
| TensorSetAddress((*tuple)[i], tuple_tensors); | |||
| } | |||
| } | |||
| } | |||
| ValuePtr GenNewTensorInner(const ValuePtr &value) { | |||
| std::vector<ValuePtr> value_list; | |||
| if (value->isa<tensor::Tensor>()) { | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| // return std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape()); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(*tensor); | |||
| new_tensor->set_device_address(nullptr); | |||
| return new_tensor; | |||
| } | |||
| if (value->isa<ValueTuple>()) { | |||
| auto tuple = value->cast<ValueTuplePtr>(); | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| value_list.push_back(GenNewTensorInner((*tuple)[i])); | |||
| } | |||
| return std::make_shared<ValueTuple>(value_list); | |||
| } | |||
| return value; | |||
| } | |||
| ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value) { | |||
| ValuePtr out = value; | |||
| auto ref_size = mng->node_users()[node].size(); | |||
| if (ref_size < 2) { | |||
| out = GenNewTensorInner(value); | |||
| } | |||
| return out; | |||
| } | |||
| void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) { | |||
| auto forward = cnode_morph->forward().first; | |||
| @@ -266,6 +314,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor | |||
| if (!IsValueNode<FuncGraph>(input_fg)) { | |||
| return; | |||
| } | |||
| std::map<std::string, tensor::TensorPtr> tuple_tensors; | |||
| auto equivdout = cnode_input->cast<CNodePtr>(); | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(input_fg); | |||
| auto manager = Manage({fg, func_graph}, false); | |||
| @@ -273,15 +322,10 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor | |||
| auto forward_value = forward; | |||
| if (!forward_id.empty() && ref_size > 1) { | |||
| auto inst = pynative::PynativeExecutor::GetInstance(); | |||
| inst->SaveOpForwardValue(forward_id, forward_value); | |||
| } | |||
| if (ref_size < 2) { | |||
| auto tensor = forward->cast<tensor::TensorPtr>(); | |||
| if (tensor != nullptr) { | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape()); | |||
| forward_value = new_tensor; | |||
| } | |||
| inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors); | |||
| } | |||
| forward_value = GenNewTensor(manager, equivdout, forward); | |||
| MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward; | |||
| auto value_node = NewValueNode(forward_value); | |||
| value_node->set_has_new_value(true); | |||
| @@ -300,13 +344,43 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor | |||
| if (para_ref_size > 0 && input_value.first != nullptr) { | |||
| MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first; | |||
| auto inst = pynative::PynativeExecutor::GetInstance(); | |||
| inst->SaveOpForwardValue(input_value.second, input_value.first); | |||
| if (!input_value.second.empty()) { | |||
| inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors); | |||
| } | |||
| auto input_value_node = NewValueNode(input_value.first); | |||
| input_value_node->set_has_new_value(true); | |||
| manager->Replace(paras[i], input_value_node); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Start opt node" << fg->output()->DebugString(4); | |||
| auto res = std::make_shared<pipeline::Resource>(); | |||
| res->set_manager(manager); | |||
| res->set_func_graph(fg); | |||
| PynativeElimOpt(res); | |||
| auto out = fg->output()->cast<CNodePtr>(); | |||
| auto c_input = out->input(1); | |||
| if (!c_input->isa<ValueNode>()) { | |||
| return; | |||
| } | |||
| auto out_node = c_input->cast<ValueNodePtr>(); | |||
| out_node->set_value(GenNewTensor(manager, out_node, out_node->value())); | |||
| cnode_morph->clear_inputs_value(); | |||
| if (tuple_tensors.size() != 0) { | |||
| MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4); | |||
| for (auto &g : manager->func_graphs()) { | |||
| for (auto &node : g->value_nodes()) { | |||
| MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString(); | |||
| auto vnode = node.first->cast<ValueNodePtr>()->value(); | |||
| TensorSetAddress(vnode, &tuple_tensors); | |||
| } | |||
| } | |||
| } | |||
| fg->ClearAllManagerInfo(); | |||
| func_graph->ClearAllManagerInfo(); | |||
| return; | |||
| } | |||
| @@ -59,6 +59,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | |||
| prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup); | |||
| zero_like_fill_zero_ = | |||
| MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||
| adjust_all_reduce_mul_add_ = | |||
| @@ -123,6 +123,9 @@ class OptimizeIRPassLib { | |||
| // SwitchLayer defer inline | |||
| SubstitutionPtr switch_layer_defer_inline_; | |||
| // Pynative Eliminate | |||
| SubstitutionPtr pynative_eliminate_; | |||
| }; | |||
| // the collection of irpass for resolve action | |||
| @@ -21,6 +21,7 @@ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "frontend/optimizer/optimizer_caller.h" | |||
| #include "ir/pattern_matcher.h" | |||
| @@ -31,6 +32,7 @@ | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "pipeline/jit/parse/resolve.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -206,6 +208,153 @@ class DependValueElim : public OptimizerCaller { | |||
| } | |||
| }; | |||
| // {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} ->Tensor(0, shape(Xy)) | |||
| // {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} ->Tensor(0, shape(Xy)) | |||
| // {{prim::resolve, CommonOPS, getitem}, (tensor0, tensor1,...), 0} -> tensor0 | |||
| class PynativeEliminater : public OptimizerCaller { | |||
| bool CheckNameSpaceVNode(const AnfNodePtr &node, const std::string &str_value) { | |||
| ValueNodePtr value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return false; | |||
| } | |||
| return GetValueNode<parse::NameSpacePtr>(value_node)->module() == str_value; | |||
| } | |||
| bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) { | |||
| ValueNodePtr value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return false; | |||
| } | |||
| return GetValueNode<parse::SymbolPtr>(value_node)->symbol() == str_value; | |||
| } | |||
| bool CheckStrVNode(const AnfNodePtr &node, const std::string &str_value) { | |||
| ValueNodePtr value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return false; | |||
| } | |||
| return GetValueNode<StringImmPtr>(value_node)->value() == str_value; | |||
| } | |||
| ValuePtr FillGetItem(const ValuePtr &value, const ValuePtr &idx) { | |||
| MS_LOG(DEBUG) << "Start FillGetItem" << value->ToString() << idx->ToString(); | |||
| if (!idx->isa<Int32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "Getitem idx must int:" << idx->ToString(); | |||
| } | |||
| if (!value->isa<ValueTuple>()) { | |||
| MS_LOG(EXCEPTION) << "Getitem value must tuple:" << value->ToString(); | |||
| } | |||
| auto value_tuple = value->cast<ValueTuplePtr>(); | |||
| int idx_t = idx->cast<Int32ImmPtr>()->value(); | |||
| MS_LOG(DEBUG) << "Fill getitem" << idx_t << (*value_tuple)[idx_t]->ToString(); | |||
| return (*value_tuple)[idx_t]; | |||
| } | |||
| ValuePtr FillZero(const ValuePtr &value) { | |||
| MS_LOG(DEBUG) << "Start FillZero"; | |||
| ValuePtr out = nullptr; | |||
| if (value->isa<Int32Imm>()) { | |||
| return MakeValue(0); | |||
| } | |||
| if (value->isa<tensor::Tensor>()) { | |||
| MS_LOG(DEBUG) << "Start FillZero Tensor"; | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| tensor::TensorPtr out_t = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape()); | |||
| char *data = reinterpret_cast<char *>(out_t->data_c()); | |||
| std::fill(data, data + out_t->data().nbytes(), 0); | |||
| out = out_t; | |||
| } | |||
| std::vector<ValuePtr> value_list; | |||
| if (value->isa<ValueTuple>()) { | |||
| MS_LOG(DEBUG) << "Start FillZero Tuple" << value->ToString(); | |||
| auto value_tuple = value->cast<ValueTuplePtr>(); | |||
| for (size_t i = 0; i < value_tuple->size(); i++) { | |||
| value_list.push_back(FillZero((*value_tuple)[i])); | |||
| } | |||
| out = std::make_shared<ValueTuple>(value_list); | |||
| } | |||
| if (out == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FillZero failed:" << value->ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "Result: " << out->ToString(); | |||
| return out; | |||
| } | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4); | |||
| PatternNode<AnfNodePtr> symbol_str_vnode, c_vnode, zeros_like_vnode, getitem_vnode, arg, arg1; | |||
| auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode); | |||
| auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode); | |||
| auto pattern = PCNode(getattr, arg); | |||
| if ((pattern).TryCapture(node) && | |||
| (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && | |||
| CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| value_node->set_value(FillZero(value_node->value())); | |||
| MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4); | |||
| return rep; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4); | |||
| auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode); | |||
| auto pattern1 = PCNode(resolve1, arg); | |||
| if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && | |||
| CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| value_node->set_value(FillZero(value_node->value())); | |||
| MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4); | |||
| return rep; | |||
| } | |||
| } | |||
| } | |||
| // resolve(CommonOPS, getitem)((tensors), 3) | |||
| auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode); | |||
| auto pattern2 = PCNode(resolve2, arg, arg1); | |||
| if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") && | |||
| CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4); | |||
| ValueNodePtr new_node; | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto rep1 = (arg1).GetNode(node); | |||
| if (rep1 != nullptr) { | |||
| if (rep1->isa<ValueNode>()) { | |||
| auto idx = rep1->cast<ValueNodePtr>(); | |||
| if (!value_node->value()->isa<ValueTuple>()) { | |||
| return nullptr; | |||
| } | |||
| new_node = NewValueNode(FillGetItem(value_node->value(), idx->value())); | |||
| new_node->set_has_new_value(value_node->has_new_value()); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4); | |||
| return new_node; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "End Replace " << node->DebugString(4); | |||
| return nullptr; | |||
| } | |||
| }; | |||
| class AllReduceConstElim : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| @@ -185,9 +185,11 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||
| auto fg_name = | |||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| DumpIR(fg_name + ".ir", func_graph); | |||
| ExportIR(fg_name + ".dat", "", func_graph); | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| ExportIR(fg_name + ".dat", "", func_graph); | |||
| } | |||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | |||
| } | |||
| } | |||
| @@ -314,6 +314,16 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa | |||
| bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } | |||
// Entry point for the pynative elimination optimization pass.
// Validates that the resource carries both a manager and a func_graph
// (raising via MS_LOG(EXCEPTION) otherwise — checked in this order, so the
// manager message fires first when both are missing), then delegates to
// PynativeOptPass, which runs the pynative_eliminate pass group.
bool PynativeElimOpt(const ResourcePtr &res) {
  if (res->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
  }
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
  }
  return PynativeOptPass(res);
}
| static bool IsCtrlSink() { | |||
| auto ms_ctx = MsContext::GetInstance(); | |||
| if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) { | |||
| @@ -36,6 +36,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res); | |||
| bool GeOptimizeAction(const ResourcePtr &res); | |||
| bool VmOptimizeAction(const ResourcePtr &res); | |||
| bool PynativeOptimizeAction(const ResourcePtr &res); | |||
| bool PynativeElimOpt(const ResourcePtr &res); | |||
| bool TaskEmitAction(const ResourcePtr &res); | |||
| bool ExecuteAction(const ResourcePtr &res); | |||
| bool StartPSWorkerAction(const ResourcePtr &res); | |||
| @@ -215,6 +215,17 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({ | |||
| irpass.pynative_eliminate_, | |||
| }); | |||
| OptPassGroupMap map({ | |||
| {"pynative_eliminate", pynative_eliminate}, | |||
| }); | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig interface_fusion = opt::OptPassConfig({ | |||
| irpass.mark_interface_fusion_, | |||
| @@ -422,6 +433,16 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| bool PynativeOptPass(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| auto pynative_opt = GetOptPassesPynativeElim(irpass); | |||
| auto pynative_opt_opt = opt::Optimizer::MakeOptimizer("pynative_opt", res, pynative_opt); | |||
| (void)pynative_opt_opt->step(func_graph, false); | |||
| return true; | |||
| } | |||
| std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| {"clean_after_opta", CleanAfterOptAPass}, | |||
| @@ -38,6 +38,7 @@ bool ConvertPrepareAdapt(const ResourcePtr &res); | |||
| bool AddControlDependPass(const ResourcePtr &res); | |||
| bool InferenceOptPreparePass(const ResourcePtr &res); | |||
| void ReclaimOptimizer(); | |||
| bool PynativeOptPass(const ResourcePtr &res); | |||
| } // namespace pipeline | |||
| } // namespace mindspore | |||
| @@ -206,7 +206,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto out = ToAbstract(value_node->value(), conf->context(), conf); | |||
| if (value_node->has_new_value()) { | |||
| if (value_node->has_new_value() && out->isa<AbstractTensor>()) { | |||
| out = out->Broaden(); | |||
| } | |||
| return out; | |||
| @@ -59,6 +59,8 @@ | |||
| #include "pipeline/pynative/pynative_execute_ge.h" | |||
| #endif | |||
| #include "debug/anf_ir_dump.h" | |||
| using mindspore::tensor::TensorPy; | |||
| const char SINGLE_OP_GRAPH[] = "single_op_graph"; | |||
| @@ -780,19 +782,79 @@ void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::ob | |||
| set_pyobj(curr_g_, obj_id); | |||
| } | |||
| void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value) { | |||
| auto iter = op_forward_map_.find(id); | |||
| if (iter != op_forward_map_.end()) { | |||
| void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) { | |||
| if (t_map == nullptr) { | |||
| return; | |||
| } | |||
| auto tuple_info_iter = obj_to_forward_id_tuple_info_.find(id); | |||
| ValuePtr temp_value = value; | |||
| if (tuple_info_iter != obj_to_forward_id_tuple_info_.end()) { | |||
| temp_value = tuple_info_iter->second; | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| ValuePtr tuple_i = (*tuple)[i]; | |||
| if (tuple_i->isa<tensor::Tensor>()) { | |||
| auto t = tuple_i->cast<tensor::TensorPtr>(); | |||
| (*t_map)[t->id()] = t; | |||
| } else if (tuple_i->isa<ValueTuple>()) { | |||
| GenTupleMap(tuple_i->cast<ValueTuplePtr>(), t_map); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "End GenTupleMap" << tuple->ToString(); | |||
| } | |||
| ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple) { | |||
| std::vector<ValuePtr> value_list; | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| ValuePtr tuple_i = (*tuple)[i]; | |||
| if (tuple_i->isa<tensor::Tensor>()) { | |||
| auto t = tuple_i->cast<tensor::TensorPtr>(); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(*t); | |||
| new_tensor->set_device_address(nullptr); | |||
| value_list.push_back(new_tensor); | |||
| } else if (tuple_i->isa<ValueTuple>()) { | |||
| value_list.push_back(CleanTupleAddr(tuple_i->cast<ValueTuplePtr>())); | |||
| } else { | |||
| MS_LOG(DEBUG) << "in value" << tuple_i->ToString(); | |||
| value_list.push_back(tuple_i); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "End CleanTupleAddr"; | |||
| return std::make_shared<ValueTuple>(value_list); | |||
| } | |||
| void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value, | |||
| std::map<std::string, tensor::TensorPtr> *t_map) { | |||
| if (op_forward_map_.find(id) != op_forward_map_.end()) { | |||
| if (op_forward_map_[id]->isa<ValueTuple>()) { | |||
| // for one op have multi outputs but save only one tensor | |||
| if (value->isa<tensor::Tensor>()) { | |||
| auto tuple = op_forward_map_[id]->cast<ValueTuplePtr>(); | |||
| auto value_t = value->cast<tensor::TensorPtr>(); | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| if ((*tuple)[i]->isa<tensor::Tensor>()) { | |||
| auto tuple_t = (*tuple)[i]->cast<tensor::TensorPtr>(); | |||
| if (value_t->id() == tuple_t->id()) { | |||
| tuple_t->set_device_address(value_t->device_address()); | |||
| MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (value->isa<ValueTuple>() && t_map != nullptr) { | |||
| GenTupleMap(op_forward_map_[id]->cast<ValueTuplePtr>(), t_map); | |||
| } | |||
| MS_LOG(DEBUG) << "Save op forward value: " | |||
| << "(" << id << "), " << op_forward_map_[id]->ToString(); | |||
| return; | |||
| } | |||
| if (value->isa<ValueTuple>() && t_map == nullptr) { | |||
| // make cnode gen all tuple node and set device_address be null | |||
| op_forward_map_[id] = CleanTupleAddr(value->cast<ValueTuplePtr>()); | |||
| } else { | |||
| op_forward_map_[id] = value; | |||
| } | |||
| op_forward_map_[id] = temp_value; | |||
| MS_LOG(DEBUG) << "Save op forward value: " | |||
| << "(" << id << "), " << temp_value; | |||
| << "(" << id << "), " << value->ToString(); | |||
| } | |||
| void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { | |||
| @@ -828,7 +890,7 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN | |||
| auto tuple_item_id = GetId(tuple_item[i]); | |||
| obj_to_forward_id_[tuple_item_id] = op_id; | |||
| } | |||
| obj_to_forward_id_tuple_info_[op_id] = value; | |||
| SaveOpForwardValue(op_id, value, nullptr); | |||
| } | |||
| obj_to_forward_id_[out_id] = op_id; | |||
| } | |||
| @@ -840,12 +902,24 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { | |||
| if (out.second.size() == 1 && out.second[0] == -1) { | |||
| return out.first; | |||
| } | |||
| auto node = out.first; | |||
| CNodePtr node = out.first->cast<CNodePtr>(); | |||
| MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); | |||
| auto abs = node->abstract(); | |||
| ValuePtr out_obj = nullptr; | |||
| if (node->forward().first != nullptr) { | |||
| out_obj = node->forward().first; | |||
| } else { | |||
| out_obj = PyAttrValue(obj); | |||
| } | |||
| for (auto &idx : out.second) { | |||
| std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)}; | |||
| node = curr_g_->NewCNode(tuple_get_item_inputs); | |||
| if (out_obj->isa<ValueTuple>()) { | |||
| node->add_input_value(out_obj, ""); | |||
| node->add_input_value(MakeValue(idx), ""); | |||
| out_obj = (*out_obj->cast<ValueTuplePtr>())[idx]; | |||
| node->set_forward(out_obj, ""); | |||
| } | |||
| if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) { | |||
| auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx]; | |||
| MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString(); | |||
| @@ -856,7 +930,6 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { | |||
| node_abs_map_[id] = node->abstract(); | |||
| } | |||
| MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); | |||
| node->cast<CNodePtr>()->set_forward(PyAttrValue(obj), ""); | |||
| return node; | |||
| } | |||
| @@ -1306,7 +1379,13 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje | |||
| } | |||
| set_obj_node_map(graph_prev, GetId(out), out_cnode); | |||
| } else { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| DumpIR("before_resolve.ir", newfg); | |||
| } | |||
| parse::ResolveFuncGraph(newfg, resource_); | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| DumpIR("after_resolve.ir", newfg); | |||
| } | |||
| resource_->set_func_graph(newfg); | |||
| Popp(); | |||
| } | |||
| @@ -1426,7 +1505,13 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje | |||
| MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id; | |||
| } | |||
| top_g_ = cell_graph_map_[forward_cell_id]; | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| DumpIR("before_grad.ir", resource_->func_graph()); | |||
| } | |||
| auto g = GradGraph(resource_->func_graph(), grad, w_args, size); | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| DumpIR("after_grad.ir", g); | |||
| } | |||
| resource_->set_func_graph(g); | |||
| resource_->manager()->KeepRoots({g}); | |||
| @@ -25,6 +25,7 @@ | |||
| #include <mutex> | |||
| #include <stack> | |||
| #include <set> | |||
| #include <map> | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/numpy.h" | |||
| @@ -121,7 +122,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| abstract::AbstractBasePtrList *args_spec_list); | |||
| void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode); | |||
| ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); | |||
| void SaveOpForwardValue(const std::string &id, const ValuePtr &value); | |||
| void SaveOpForwardValue(const std::string &id, const ValuePtr &value, | |||
| std::map<std::string, tensor::TensorPtr> *t_map); | |||
| void SaveForwardResult(const CNodePtr &cnode, const py::object &out); | |||
| void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out); | |||
| @@ -154,7 +156,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| std::unordered_map<std::string, ValuePtr> op_forward_map_; | |||
| std::unordered_map<std::string, size_t> op_id_map_; | |||
| std::unordered_map<std::string, std::string> obj_to_forward_id_; | |||
| std::unordered_map<std::string, ValuePtr> obj_to_forward_id_tuple_info_; | |||
| std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_; | |||
| std::unordered_map<std::string, FuncGraphPtr> df_builder_map_; | |||
| // the stack that records the context of graph created, the bottom is the top graph | |||
| @@ -85,6 +85,8 @@ void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||
| mem_manager_->ResetDynamicMemory(); | |||
| RunOpAssignInputMemory(input_tensors, graph); | |||
| AssignStaticMemoryValueNode(graph); | |||
| RunOpAssignOutputNodeMemory(pre_output_value, graph); | |||
| @@ -268,7 +270,8 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value | |||
| MS_EXCEPTION_IF_NULL(real_output_cnode); | |||
| MS_EXCEPTION_IF_NULL(pre_output_tensors[i]); | |||
| if (pre_output_tensors[i]->device_address() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The address of pre output tensor [" << i << "] is a nullptr!"; | |||
| MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!"; | |||
| continue; | |||
| } | |||
| if (opt::IsNopNode(real_output_cnode)) { | |||
| if (real_output_cnode->inputs().size() < 2) { | |||
| @@ -155,7 +155,7 @@ def test_softmaxloss_grad(): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.weight = Parameter(Tensor(np.ones([64, 10])), name="weight") | |||
| self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight") | |||
| self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias") | |||
| self.fc = P.MatMul() | |||
| self.fc2 = nn.Dense(10, 10) | |||
| @@ -175,7 +175,7 @@ def test_softmaxloss_grad(): | |||
| net = GradWrap(NetWithLossClass(Net())) | |||
| predict = Tensor(np.ones([1, 64])) | |||
| predict = Tensor(np.ones([1, 64]).astype(np.float32)) | |||
| label = Tensor(np.zeros([1, 10]).astype(np.float32)) | |||
| print("pynative run") | |||
| out = net(predict, label) | |||