| @@ -288,6 +288,7 @@ ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, Sm | |||||
| cg_holder.graph->options().async_exec_level = 0; | cg_holder.graph->options().async_exec_level = 0; | ||||
| cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level; | cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level; | ||||
| cg_holder.graph->options().enable_var_mem_defragment = false; | cg_holder.graph->options().enable_var_mem_defragment = false; | ||||
| cg_holder.graph->options().comp_seq_sync_device = false; | |||||
| cg_holder.graph->set_device_memory_allocator(cg_holder.allocator); | cg_holder.graph->set_device_memory_allocator(cg_holder.allocator); | ||||
| // cg_holder.graph->options().graph_opt.jit = 2; | // cg_holder.graph->options().graph_opt.jit = 2; | ||||
| VarNodeArray input_vars; | VarNodeArray input_vars; | ||||
| @@ -385,21 +385,27 @@ void ComputingGraphImpl::ComputingSequence::do_wait(bool explicit_user_wait) { | |||||
| } | } | ||||
| } | } | ||||
| for (auto cn : m_used_comp_node) { | |||||
| m_event_end.at(cn)->host_wait(); | |||||
| bool sync_device = m_owner_graph->options().comp_seq_sync_device; | |||||
| if (sync_device) { | |||||
| for (auto cn : m_used_comp_node) { | |||||
| m_event_end.at(cn)->host_wait(); | |||||
| } | |||||
| } | } | ||||
| m_wait_finished = true; | m_wait_finished = true; | ||||
| #if MGB_NEED_MEGDNN_ASYNC_ERROR | #if MGB_NEED_MEGDNN_ASYNC_ERROR | ||||
| // FIXME: It CAN NOT work well if more than one ComputingSequences has been | // FIXME: It CAN NOT work well if more than one ComputingSequences has been | ||||
| // executed on the same compnode and got AsyncError concurrently, because | // executed on the same compnode and got AsyncError concurrently, because | ||||
| // only the first async error on each comp_node would be recorded. | // only the first async error on each comp_node would be recorded. | ||||
| for (auto&& cn : m_used_comp_node) { | |||||
| auto error = cn.check_async_error(); | |||||
| if (error) { | |||||
| static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info()) | |||||
| ->opr() | |||||
| ->owner_graph() | |||||
| ->record_async_error(std::move(error)); | |||||
| if (sync_device) { | |||||
| for (auto&& cn : m_used_comp_node) { | |||||
| auto error = cn.check_async_error(); | |||||
| if (error) { | |||||
| static_cast<const OperatorNodeExcExtraInfo*>(error->extra_info()) | |||||
| ->opr() | |||||
| ->owner_graph() | |||||
| ->record_async_error(std::move(error)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -520,6 +520,9 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, | |||||
| */ | */ | ||||
| bool no_force_inplace = false; | bool no_force_inplace = false; | ||||
| //! whether to sync comp_node when waiting computing sequence | |||||
| bool comp_seq_sync_device = true; | |||||
| //! add extra deps for the comp seq if a specific var is dependent | //! add extra deps for the comp seq if a specific var is dependent | ||||
| ThinHashMap<VarNode*, VarNodeArray> extra_vardeps; | ThinHashMap<VarNode*, VarNodeArray> extra_vardeps; | ||||