| @@ -737,12 +737,33 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod | |||
| if (!anf_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(anf_node); | |||
| } | |||
| auto kernel_info = anf_node->kernel_info(); | |||
| if (kernel_info) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| if (runtime_cache->is_valid()) { | |||
| auto output = runtime_cache->get_prev_node_output(input_idx); | |||
| if (output.first != nullptr) { | |||
| return output; | |||
| } | |||
| } | |||
| } | |||
| KernelWithIndex res; | |||
| if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { | |||
| return VisitKernelWithReturnType(anf_node, 0, skip_nop_node); | |||
| res = VisitKernelWithReturnType(anf_node, 0, skip_nop_node); | |||
| } else { | |||
| auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| res = VisitKernelWithReturnType(input_node, 0, skip_nop_node); | |||
| } | |||
| auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| return VisitKernelWithReturnType(input_node, 0, skip_nop_node); | |||
| if (kernel_info) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| if (runtime_cache->is_valid()) { | |||
| runtime_cache->set_prev_node_output(input_idx, res); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { | |||
| @@ -2180,9 +2201,9 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te | |||
| for (size_t i = 0; i < input_size; ++i) { | |||
| auto input_with_index = AnfAlgo::GetPrevNodeOutput(node, i); | |||
| auto real_input = input_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(real_input); | |||
| auto cnode_input = node->input(i + 1); | |||
| MS_EXCEPTION_IF_NULL(cnode_input); | |||
| MS_EXCEPTION_IF_NULL(real_input); | |||
| if (depend_tensors != nullptr) { | |||
| auto iter_tensor = depend_tensors->find(i); | |||
| if (iter_tensor != depend_tensors->end()) { | |||
| @@ -2202,28 +2223,32 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te | |||
| } | |||
| } | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) { | |||
| auto base_shape = real_input->Shape(); | |||
| if (!base_shape->isa<abstract::TupleShape>()) { | |||
| MS_LOG(EXCEPTION) << "Node:" << node->DebugString() | |||
| << " input is a tuple_get_item but real input node shape is not a TupleShape. trace: " | |||
| << trace::DumpSourceLines(real_input); | |||
| } | |||
| auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(abs); | |||
| auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>()); | |||
| auto abs_i = abs->elements()[tuple_get_item_indexk]; | |||
| (void)args_spec_list.emplace_back(abs_i); | |||
| } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) { | |||
| (void)args_spec_list.emplace_back(cnode_input->abstract()); | |||
| } else { | |||
| (void)args_spec_list.emplace_back(real_input->abstract()); | |||
| } | |||
| AddArgList(&args_spec_list, cnode_input, real_input, i); | |||
| } | |||
| auto eval_result = opt::CppInferShape(primitive, args_spec_list); | |||
| node->set_abstract(eval_result); | |||
| } | |||
| void AnfRuntimeAlgorithm::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input, | |||
| const AnfNodePtr &real_input, size_t index) { | |||
| if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) { | |||
| auto base_shape = real_input->Shape(); | |||
| if (!base_shape->isa<abstract::TupleShape>()) { | |||
| MS_LOG(EXCEPTION) << "Node input is a tuple_get_item but real input node shape is not a TupleShape. trace: " | |||
| << trace::DumpSourceLines(real_input); | |||
| } | |||
| auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(abs); | |||
| auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>()); | |||
| auto abs_i = abs->elements()[tuple_get_item_indexk]; | |||
| (void)args_spec_list->emplace_back(abs_i); | |||
| } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) { | |||
| (void)args_spec_list->emplace_back(cnode_input->abstract()); | |||
| } else { | |||
| (void)args_spec_list->emplace_back(real_input->abstract()); | |||
| } | |||
| } | |||
| void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) { | |||
| auto return_node = root_graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| @@ -296,6 +296,8 @@ class AnfRuntimeAlgorithm { | |||
| static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index); | |||
| static bool IsNodeDynamicShape(const AnfNodePtr &node); | |||
| static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr); | |||
| static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input, | |||
| const AnfNodePtr &real_input, size_t index); | |||
| static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | |||
| static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | |||
| // Find real input nodes. | |||
| @@ -1001,20 +1001,21 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5 | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, | |||
| const int64_t groups, const std::vector<int64_t> &input_hidden_size) { | |||
| using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | |||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, | |||
| {kOpFormat_NHWC, NhwcDeviceShape}, | |||
| {kOpFormat_HWCN, HwchDeviceShape}, | |||
| {kOpFormat_FRAC_Z, FracZDeviceShape}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, | |||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, | |||
| {kOpFormat_NCDHW, NcdhwDeviceShape}, | |||
| {kOpFormat_ChannelLast, ChannelLastDeviceShape}, | |||
| {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}, | |||
| {kOpFormat_FRAC_NZ, FracNZDeviceShape}, | |||
| {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}}; | |||
| static const std::map<std::string, DeviceShapeTransfer> device_shape_map{ | |||
| {kOpFormat_NCHW, NchwDeviceShape}, | |||
| {kOpFormat_NHWC, NhwcDeviceShape}, | |||
| {kOpFormat_HWCN, HwchDeviceShape}, | |||
| {kOpFormat_FRAC_Z, FracZDeviceShape}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, | |||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, | |||
| {kOpFormat_NCDHW, NcdhwDeviceShape}, | |||
| {kOpFormat_ChannelLast, ChannelLastDeviceShape}, | |||
| {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}, | |||
| {kOpFormat_FRAC_NZ, FracNZDeviceShape}, | |||
| {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}}; | |||
| if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | |||
| return shape; | |||
| @@ -47,61 +47,66 @@ void DynamicKernel::Initialize() { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Have depends"; | |||
| (void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_), | |||
| (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| MS_LOG(INFO) << "Init End"; | |||
| } | |||
| int DynamicKernel::GetKernelType() const { return AnfAlgo::GetKernelType(cnode_ptr_.lock()); } | |||
| void DynamicKernel::RebuildDependTensor() { | |||
| depend_tensor_map_.clear(); | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| for (auto depend : depend_list_) { | |||
| auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, depend); | |||
| bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT); | |||
| auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend, skip_nop_node); | |||
| std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second); | |||
| auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second); | |||
| auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes); | |||
| MS_EXCEPTION_IF_NULL(out_tensor); | |||
| // The second parameter must be false, otherwise the device address cannot be released and allocated, and the | |||
| // address size will be wrong in the dynamic shape scenario. | |||
| out_tensor->set_device_address(output_addr, false); | |||
| auto ret = depend_tensor_map_.try_emplace(depend, out_tensor); | |||
| if (!ret.second) { | |||
| MS_LOG(EXCEPTION) << "Insert map failed"; | |||
| } | |||
| } | |||
| } | |||
| void DynamicKernel::InferShape() { | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope(); | |||
| InferShapeRecursive(); | |||
| depend_tensor_map_.clear(); | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Invalid inputs"; | |||
| } | |||
| // rebuild depend tensor map for gpu dynamic memory allocation. | |||
| RebuildDependTensor(); | |||
| AnfAlgo::InferShape(cnode, &depend_tensor_map_); | |||
| } | |||
| void DynamicKernel::InferShapeRecursive() { | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| AbstractBasePtrList args_spec_list; | |||
| auto primitive = GetValueNode<PrimitivePtr>(inputs[0]); | |||
| auto input_size = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t i = 0; i < input_size; i++) { | |||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i); | |||
| auto input_node = input_node_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| InferShapeForNopNode(&input_node); | |||
| auto real_input = input_node_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(real_input); | |||
| auto cnode_input = cnode->input(i + 1); | |||
| MS_EXCEPTION_IF_NULL(cnode_input); | |||
| InferShapeForNopNode(&real_input); | |||
| if (depend_list_.find(i) != depend_list_.end()) { | |||
| auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i); | |||
| bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT); | |||
| auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node); | |||
| std::vector<int64_t> shapes = | |||
| trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second); | |||
| auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second); | |||
| auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes); | |||
| MS_EXCEPTION_IF_NULL(out_tensor); | |||
| // The second parameter must be false, otherwise the device address cannot be released and allocated, and the | |||
| // address size will be wrong in the dynamic shape scenario. | |||
| out_tensor->set_device_address(output_addr, false); | |||
| auto ret = depend_tensor_map_.try_emplace(i, out_tensor); | |||
| if (!ret.second) { | |||
| MS_LOG(EXCEPTION) << "Insert map failed"; | |||
| } | |||
| out_tensor->data_sync(); | |||
| auto real_abs = real_input->abstract(); | |||
| if (real_abs->isa<abstract::AbstractTensor>()) { | |||
| real_input->abstract()->set_value(out_tensor); | |||
| } else if (real_abs->isa<abstract::AbstractTuple>()) { | |||
| auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>()); | |||
| auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | |||
| auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index]; | |||
| tuple_elements->set_value(out_tensor); | |||
| } | |||
| } | |||
| AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i); | |||
| } | |||
| auto eval_result = opt::CppInferShape(primitive, args_spec_list); | |||
| cnode->set_abstract(eval_result); | |||
| } | |||
| void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) { | |||
| @@ -21,6 +21,7 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <set> | |||
| #include "ir/anf.h" | |||
| #include "ir/tensor.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| @@ -46,8 +47,6 @@ class DynamicKernel { | |||
| [[nodiscard]] int GetKernelType() const; | |||
| protected: | |||
| void RebuildDependTensor(); | |||
| void InferShapeRecursive(); | |||
| static void InferShapeForNopNode(AnfNodePtr *input_node); | |||
| void *stream_; | |||
| @@ -55,7 +54,7 @@ class DynamicKernel { | |||
| bool is_dynamic_shape_; | |||
| bool is_input_dynamic_shape_; | |||
| bool is_output_dynamic_shape_; | |||
| std::vector<uint32_t> depend_list_; | |||
| std::set<uint32_t> depend_list_; | |||
| std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_; | |||
| }; | |||
| using DynamicKernelPtr = std::shared_ptr<DynamicKernel>; | |||
| @@ -477,6 +477,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic | |||
| (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels); | |||
| #endif | |||
| device_context->EnableRuntimeCache(graph); | |||
| session_->DumpGraph(graph); | |||
| return graph->graph_id(); | |||
| } | |||
| @@ -156,6 +156,20 @@ class DeviceContext { | |||
| // Dump all graphs. | |||
| virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {} | |||
| void EnableRuntimeCache(const KernelGraphPtr &graph) const { | |||
| auto node_list = graph->TopoSort(graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| auto kernel_info = node->kernel_info(); | |||
| if (!kernel_info) { | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| runtime_cache->set_valid(); | |||
| } | |||
| } | |||
| protected: | |||
| DeviceContextKey device_context_key_; | |||
| @@ -370,7 +370,7 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const { | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER); | |||
| bool is_pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode; | |||
| std::vector<int64_t> dynamic_shape_depends = abstract::GetDependsFormMap(kernel); | |||
| auto dynamic_shape_depends = abstract::GetDependsFormMap(kernel); | |||
| if ((is_pynative_infer || is_pynative_mode) && dynamic_shape_depends.empty()) { | |||
| return; | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include <set> | |||
| #include "ops/exp.h" | |||
| #include "ops/log.h" | |||
| #include "ops/reciprocal.h" | |||
| @@ -39,9 +40,9 @@ | |||
| #include "ops/grad/slice_grad.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| using ShapeVec = std::vector<int64_t>; | |||
| using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeVec>; | |||
| std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| using ShapeSet = std::set<int64_t>; | |||
| using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeSet>; | |||
| static const auto &kOneHot = prim::kPrimOneHot->name(); | |||
| static const auto &kDropoutGenMask = prim::kPrimDropoutGenMask->name(); | |||
| static const auto &kTranspose = prim::kPrimTranspose->name(); | |||
| @@ -63,24 +64,24 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| static const auto &kReshape = prim::kPrimReshape->name(); | |||
| static const auto &kDynamicReshape = prim::kPrimDynamicReshape->name(); | |||
| // Common dynamic shape depends. | |||
| static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeVec{2}}, | |||
| {kUnsortedSegmentMin, ShapeVec{2}}, | |||
| {kUnsortedSegmentMax, ShapeVec{2}}, | |||
| {kGather, ShapeVec{2}}, | |||
| {kGatherV2, ShapeVec{2}}, | |||
| {kRange, ShapeVec{0, 1, 2}}, | |||
| {kConv2DBackpropFilter, ShapeVec{2}}, | |||
| {kConv2DBackpropInput, ShapeVec{2}}, | |||
| {kOneHot, ShapeVec{1, 3}}, | |||
| {kDropoutGenMask, ShapeVec{0}}, | |||
| {kStridedSlice, ShapeVec{1, 2, 3}}, | |||
| {kStridedSliceGrad, ShapeVec{1, 2, 3, 4}}, | |||
| {kTile, ShapeVec{1}}, | |||
| {kReshape, ShapeVec{1}}, | |||
| {kDynamicReshape, ShapeVec{1}}, | |||
| {kSlice, ShapeVec{1, 2}}, | |||
| {kSliceGrad, ShapeVec{2, 3}}, | |||
| {kDynamicBroadcastTo, ShapeVec{1}}}; | |||
| static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}}, | |||
| {kUnsortedSegmentMin, ShapeSet{2}}, | |||
| {kUnsortedSegmentMax, ShapeSet{2}}, | |||
| {kGather, ShapeSet{2}}, | |||
| {kGatherV2, ShapeSet{2}}, | |||
| {kRange, ShapeSet{0, 1, 2}}, | |||
| {kConv2DBackpropFilter, ShapeSet{2}}, | |||
| {kConv2DBackpropInput, ShapeSet{2}}, | |||
| {kOneHot, ShapeSet{1, 3}}, | |||
| {kDropoutGenMask, ShapeSet{0}}, | |||
| {kStridedSlice, ShapeSet{1, 2, 3}}, | |||
| {kStridedSliceGrad, ShapeSet{1, 2, 3, 4}}, | |||
| {kTile, ShapeSet{1}}, | |||
| {kReshape, ShapeSet{1}}, | |||
| {kDynamicReshape, ShapeSet{1}}, | |||
| {kSlice, ShapeSet{1, 2}}, | |||
| {kSliceGrad, ShapeSet{2, 3}}, | |||
| {kDynamicBroadcastTo, ShapeSet{1}}}; | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().empty()) { | |||
| @@ -101,9 +102,9 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| auto iter = dynamic_shape_depends.find(prim_name); | |||
| if (iter != dynamic_shape_depends.end()) { | |||
| int64_t cnode_input_size = SizeToLong(cnode->inputs().size()); | |||
| std::vector<int64_t> res; | |||
| ShapeSet res; | |||
| auto ori = iter->second; | |||
| (void)std::copy_if(ori.begin(), ori.end(), std::back_inserter(res), | |||
| (void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()), | |||
| [&](auto idx) { return idx < cnode_input_size - 1; }); | |||
| return res; | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <memory> | |||
| #include "utils/hash_map.h" | |||
| #include "ir/primitive.h" | |||
| @@ -50,7 +51,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap(); | |||
| StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode); | |||
| std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode); | |||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); | |||
| @@ -622,14 +622,36 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) { | |||
| } | |||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| auto target = GetOriginNodeTarget(node); | |||
| if (target != kTargetUnDefined) { | |||
| return target; | |||
| auto kernel_info = node->kernel_info(); | |||
| if (kernel_info != nullptr) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| if (runtime_cache->is_valid()) { | |||
| auto tmp_target = runtime_cache->device_target(); | |||
| if (!tmp_target.empty()) { | |||
| return tmp_target; | |||
| } | |||
| } | |||
| } | |||
| return default_target; | |||
| std::string target; | |||
| auto ori_target = GetOriginNodeTarget(node); | |||
| if (ori_target != kTargetUnDefined) { | |||
| target = ori_target; | |||
| } else { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| } | |||
| if (kernel_info != nullptr) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| if (runtime_cache->is_valid()) { | |||
| runtime_cache->set_device_target(target); | |||
| } | |||
| } | |||
| return target; | |||
| } | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| @@ -18,6 +18,10 @@ | |||
| #define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_ | |||
| #include <memory> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <string> | |||
| #include "utils/info.h" | |||
| namespace mindspore { | |||
| enum Axis : int { | |||
| @@ -26,11 +30,51 @@ enum Axis : int { | |||
| H, | |||
| W, | |||
| }; | |||
| // Cache some runtime information which not be changed. | |||
| class RuntimeCache { | |||
| public: | |||
| std::pair<AnfNodePtr, size_t> get_prev_node_output(size_t index) { | |||
| auto it = prev_node_output_map_.find(index); | |||
| if (it != prev_node_output_map_.end()) { | |||
| return it->second; | |||
| } else { | |||
| return std::pair<AnfNodePtr, size_t>(); | |||
| } | |||
| } | |||
| void set_prev_node_output(size_t index, std::pair<AnfNodePtr, size_t> output) { | |||
| auto pr = std::make_pair(index, output); | |||
| (void)prev_node_output_map_.insert(pr); | |||
| } | |||
| std::string device_target() { return device_target_; } | |||
| void set_device_target(const std::string &target) { device_target_ = target; } | |||
| bool is_valid() { return is_valid_; } | |||
| void set_valid() { is_valid_ = true; } | |||
| void set_output_tensor_num(const ssize_t output_tensor_num) { output_tensor_num_ = output_tensor_num; } | |||
| ssize_t output_tensor_num() const { return output_tensor_num_; } | |||
| void set_real_kernel(enum CacheBool b) { is_real_kernel_ = b; } | |||
| enum CacheBool is_real_kernel() { return is_real_kernel_; } | |||
| private: | |||
| bool is_valid_{false}; | |||
| std::map<size_t, std::pair<AnfNodePtr, size_t>> prev_node_output_map_; | |||
| std::string device_target_; | |||
| ssize_t output_tensor_num_ = -1; | |||
| enum CacheBool is_real_kernel_ = CacheBool::UNCACHED; | |||
| }; | |||
| // Interface for device kernel program information. | |||
| class KernelInfoDevice { | |||
| public: | |||
| // If kernel program was built and build info is set. | |||
| virtual bool has_build_info() const = 0; | |||
| RuntimeCache *runtime_cache() { return &runtime_cache_; } | |||
| private: | |||
| RuntimeCache runtime_cache_; | |||
| }; | |||
| using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>; | |||
| } // namespace mindspore | |||
| @@ -108,7 +108,28 @@ bool AnfUtils::IsRealKernel(const AnfNodePtr &node) { | |||
| if (cnode->size() == 0) { | |||
| MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString() << trace::DumpSourceLines(node); | |||
| } | |||
| return !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims); | |||
| auto kernel_info = cnode->kernel_info(); | |||
| if (kernel_info) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| if (runtime_cache->is_real_kernel() != CacheBool::UNCACHED) { | |||
| return (runtime_cache->is_real_kernel() == CacheBool::TRUE); | |||
| } | |||
| } | |||
| bool res = !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims); | |||
| if (kernel_info) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| MS_EXCEPTION_IF_NULL(runtime_cache); | |||
| if (res) { | |||
| runtime_cache->set_real_kernel(CacheBool::TRUE); | |||
| } else { | |||
| runtime_cache->set_real_kernel(CacheBool::FALSE); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) { | |||
| @@ -183,19 +204,38 @@ size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) { | |||
| size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| if (kernel_info) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| if (runtime_cache->is_valid()) { | |||
| ssize_t output_tensor_num = runtime_cache->output_tensor_num(); | |||
| if (output_tensor_num >= 0) { | |||
| return static_cast<size_t>(output_tensor_num); | |||
| } | |||
| } | |||
| } | |||
| size_t res; | |||
| TypePtr type = node->Type(); | |||
| if (type == nullptr) { | |||
| return 0; | |||
| } | |||
| if (type->isa<Tuple>()) { | |||
| res = 0; | |||
| } else if (type->isa<Tuple>()) { | |||
| auto tuple_type = type->cast<TuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_type); | |||
| return tuple_type->size(); | |||
| res = tuple_type->size(); | |||
| } else if (type->isa<TypeNone>()) { | |||
| res = 0; | |||
| } else { | |||
| res = 1; | |||
| } | |||
| if (type->isa<TypeNone>()) { | |||
| return 0; | |||
| if (kernel_info) { | |||
| auto runtime_cache = kernel_info->runtime_cache(); | |||
| if (runtime_cache->is_valid()) { | |||
| runtime_cache->set_output_tensor_num(static_cast<ssize_t>(res)); | |||
| } | |||
| } | |||
| return 1; | |||
| return res; | |||
| } | |||
| std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) { | |||
| @@ -29,7 +29,7 @@ | |||
| namespace mindspore { | |||
| enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 }; | |||
| typedef enum CacheBool { UNCACHED = -1, FALSE, TRUE } CacheBool; | |||
| // Location class record the location in source code. | |||
| class Location { | |||
| public: | |||