| @@ -327,6 +327,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { | |||||
| runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); | runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); | ||||
| } | } | ||||
| bool AscendSession::IsSupportSummary() { return !device::KernelAdjust::NeedInsertSwitch(); } | |||||
| void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| VectorRef *const outputs) { | VectorRef *const outputs) { | ||||
| MS_LOG(INFO) << "Start"; | MS_LOG(INFO) << "Start"; | ||||
| @@ -71,6 +71,7 @@ class AscendSession : public SessionBasic { | |||||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | ||||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | ||||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override; | GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override; | ||||
| bool IsSupportSummary() override; | |||||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | ||||
| void BuildGraphImpl(GraphId) override; | void BuildGraphImpl(GraphId) override; | ||||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| @@ -1153,6 +1153,11 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||||
| if (!exist_summary) { | if (!exist_summary) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (!IsSupportSummary()) { | |||||
| MS_LOG(ERROR) << "The Summary operator can not collect data correctly. Detail: the data sink mode is used and the" | |||||
| " sink size(in model.train() python api) is not equal to 1."; | |||||
| } | |||||
| SetSummaryNodes(graph); | SetSummaryNodes(graph); | ||||
| auto summary_outputs = graph->summary_nodes(); | auto summary_outputs = graph->summary_nodes(); | ||||
| std::map<std::string, tensor::TensorPtr> params_list; | std::map<std::string, tensor::TensorPtr> params_list; | ||||
| @@ -132,6 +132,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| friend class RunGraphTask; | friend class RunGraphTask; | ||||
| friend class BuildOpTask; | friend class BuildOpTask; | ||||
| friend class RunOpTask; | friend class RunOpTask; | ||||
| virtual bool IsSupportSummary() { return true; } | |||||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| VectorRef *outputs, | VectorRef *outputs, | ||||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | ||||