| @@ -260,3 +260,15 @@ def replace_oprs(dst, oprmap): | |||||
| repl_dst_vec.push_back(j) | repl_dst_vec.push_back(j) | ||||
| return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | ||||
def set_priority_to_id(dest_vars):
    """Assign each operator's id as its priority throughout the subgraph
    reachable from *dest_vars*, leaving non-zero (user-set) priorities intact.

    :param dest_vars: target vars representing the graph
    """
    symvars = _mgb._VectorSymbolVar()
    for var in dest_vars:
        assert isinstance(var, _mgb.SymbolVar)
        symvars.push_back(var)
    _mgb._set_priority_to_id(symvars)
| @@ -84,6 +84,8 @@ class trace: | |||||
| :param log_level: Log level. | :param log_level: Log level. | ||||
| :param sublinear_memory_config: Configuration for sublinear memory optimization. | :param sublinear_memory_config: Configuration for sublinear memory optimization. | ||||
| If not None, it enables sublinear memory optimization with given setting. | If not None, it enables sublinear memory optimization with given setting. | ||||
| :param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. | |||||
| If not None, multiple gradients will be packed and synchronized together | |||||
| :param profiling: Whether to profile compiled trace. Default: False | :param profiling: Whether to profile compiled trace. Default: False | ||||
| """ | """ | ||||
| @@ -107,6 +109,7 @@ class trace: | |||||
| opt_level: int = None, | opt_level: int = None, | ||||
| log_level: int = None, | log_level: int = None, | ||||
| sublinear_memory_config: SublinearMemoryConfig = None, | sublinear_memory_config: SublinearMemoryConfig = None, | ||||
| allreduce_pack_max_size: int = None, | |||||
| profiling: bool = False | profiling: bool = False | ||||
| ): | ): | ||||
| self.__wrapped__ = func | self.__wrapped__ = func | ||||
| @@ -114,6 +117,7 @@ class trace: | |||||
| self._graph_opt_level = opt_level | self._graph_opt_level = opt_level | ||||
| self._log_level = log_level | self._log_level = log_level | ||||
| self._sublinear_memory_config = sublinear_memory_config | self._sublinear_memory_config = sublinear_memory_config | ||||
| self._allreduce_pack_max_size = allreduce_pack_max_size | |||||
| self._status = self._UNSTARTED | self._status = self._UNSTARTED | ||||
| self._args = None | self._args = None | ||||
| self._kwargs = None | self._kwargs = None | ||||
| @@ -313,6 +317,9 @@ class trace: | |||||
| "sublinear_mem_cofig.num_worker", | "sublinear_mem_cofig.num_worker", | ||||
| self._sublinear_memory_config.num_worker, | self._sublinear_memory_config.num_worker, | ||||
| ) | ) | ||||
| # pack allreduce | |||||
| if self._allreduce_pack_max_size is not None: | |||||
| cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) | |||||
| # profile | # profile | ||||
| if self._profiling: | if self._profiling: | ||||
| self._profiler = CompGraphProfiler(cg) | self._profiler = CompGraphProfiler(cg) | ||||
| @@ -391,6 +398,7 @@ class trace: | |||||
| outputs = [outputs] | outputs = [outputs] | ||||
| # _run_wrapped has checked validity of outputs | # _run_wrapped has checked validity of outputs | ||||
| self._sym_outputs = tuple(i._symvar for i in outputs) | self._sym_outputs = tuple(i._symvar for i in outputs) | ||||
| mgb.comp_graph_tools.set_priority_to_id(self._outspec) | |||||
| self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | ||||
| def trace(self, *args: Tensor, **kwargs): | def trace(self, *args: Tensor, **kwargs): | ||||
| @@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): | |||||
| :param loss: The obtained loss tensor | :param loss: The obtained loss tensor | ||||
| """ | """ | ||||
| rst = [] | rst = [] | ||||
| priority = 0 | |||||
| params = [] | params = [] | ||||
| for group in self.param_groups: | for group in self.param_groups: | ||||
| for param in group["params"]: | for param in group["params"]: | ||||
| @@ -180,14 +179,14 @@ class Optimizer(metaclass=ABCMeta): | |||||
| for param, grad in zip(params, grads): | for param, grad in zip(params, grads): | ||||
| if is_distributed(): | if is_distributed(): | ||||
| priority += 1 | |||||
| with opr_priority_scope(cg, -priority): | |||||
| # all_reduce_mean | |||||
| with opr_priority_scope(cg, -(2 ** 30)): | |||||
| # always run all_reduce_mean first except add_update | |||||
| grad = ( | grad = ( | ||||
| all_reduce_sum(grad, "grad_" + str(get_group_id())) | all_reduce_sum(grad, "grad_" + str(get_group_id())) | ||||
| / get_world_size() | / get_world_size() | ||||
| ) | ) | ||||
| with opr_priority_scope(cg, (1 << 30) - priority): | |||||
| with opr_priority_scope(cg, -(2 ** 31)): | |||||
| # always run add_update first | |||||
| grad_update = add_update(param.grad, grad) | grad_update = add_update(param.grad, grad) | ||||
| else: | else: | ||||
| grad_update = add_update(param.grad, grad) | grad_update = add_update(param.grad, grad) | ||||
| @@ -66,6 +66,8 @@ bool _config::set_comp_graph_option( | |||||
| SET_CG_OPTION(graph_opt.jit); | SET_CG_OPTION(graph_opt.jit); | ||||
| SET_CG_OPTION(graph_opt.tensorrt); | SET_CG_OPTION(graph_opt.tensorrt); | ||||
| SET_CG_OPTION(graph_opt_level); | SET_CG_OPTION(graph_opt_level); | ||||
| SET_CG_OPTION(allreduce_pack_max_size); | |||||
| SET_CG_OPTION(allreduce_pack_ignore_first); | |||||
| SET_CG_OPTION(var_sanity_check_first_run); | SET_CG_OPTION(var_sanity_check_first_run); | ||||
| SET_CG_OPTION(no_profiling_on_shape_change); | SET_CG_OPTION(no_profiling_on_shape_change); | ||||
| SET_CG_OPTION(allocate_static_mem_after_graph_compile); | SET_CG_OPTION(allocate_static_mem_after_graph_compile); | ||||
| @@ -1,3 +1,7 @@ | |||||
| %{ | |||||
| #include "megbrain/gopt/framework.h" | |||||
| %} | |||||
| %inline { | %inline { | ||||
| SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { | SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { | ||||
| @@ -35,5 +39,17 @@ | |||||
| } | } | ||||
| return mgb::cg::replace_oprs(vars, oprmap); | return mgb::cg::replace_oprs(vars, oprmap); | ||||
| } | } | ||||
// For every operator in the subgraph reachable from dest_vars, assign the
// operator's id as its priority — but only when the priority is still the
// default value 0, so user-assigned priorities are never clobbered.
// NOTE(review): opr->id() is presumably wider than the priority field;
// confirm the narrowing is intended.
void _set_priority_to_id(const SymbolVarArray& dest_vars) {
    auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
        if (opr->node_prop().attribute().priority == 0) {
            opr->node_prop().attribute().priority = opr->id();
        }
    };
    // DepOprIter visits each dependency operator exactly once
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (const SymbolVar& var : dest_vars) {
        dep_iter.add(var);
    }
}
| } | } | ||||
| // vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} | // vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -441,12 +441,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
| optimizer.verbosity(options().log_level); | optimizer.verbosity(options().log_level); | ||||
| optimizer.enable_check_result(options().graph_opt_level < 0); | optimizer.enable_check_result(options().graph_opt_level < 0); | ||||
| if (sopr_stat.has_virtual_grad) { | if (sopr_stat.has_virtual_grad) { | ||||
| if (need_opt) | |||||
| if (need_opt) { | |||||
| #if MGB_ENABLE_OPR_MM | |||||
| optimizer.add_pass<gopt::PackAllReduceScanPass>(); | |||||
| #endif | |||||
| optimizer.add_preset_passes(false, nullptr, &options()); | optimizer.add_preset_passes(false, nullptr, &options()); | ||||
| } | |||||
| optimizer.add_pass<gopt::ExpandVirtualGradPass>(); | optimizer.add_pass<gopt::ExpandVirtualGradPass>(); | ||||
| } | } | ||||
| if (need_opt) | |||||
| if (need_opt) { | |||||
| optimizer.add_preset_passes(true, nullptr, &options()); | optimizer.add_preset_passes(true, nullptr, &options()); | ||||
| #if MGB_ENABLE_OPR_MM | |||||
| if (sopr_stat.has_virtual_grad) { | |||||
| optimizer.add_pass<gopt::PackAllReduceReplacePass>(); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| optimizer.apply_inplace(dest_vars); | optimizer.apply_inplace(dest_vars); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -327,6 +327,18 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, | |||||
| */ | */ | ||||
| int16_t graph_opt_level = 2; | int16_t graph_opt_level = 2; | ||||
| /*! | |||||
| * max size of allreduce packs in MB | |||||
| * set this option to zero to disable PackAllReducePass | |||||
| */ | |||||
| int16_t allreduce_pack_max_size = 0; | |||||
| /*! | |||||
| * do not pack the first n allreduces | |||||
| * PackAllReducePass disabled if allreduce_pack_max_size is zero | |||||
| */ | |||||
| int16_t allreduce_pack_ignore_first = 2; | |||||
| /*! | /*! | ||||
| * set logging level, larger number means more verbose | * set logging level, larger number means more verbose | ||||
| * 0: no log info | * 0: no log info | ||||
| @@ -183,7 +183,6 @@ SymbolVarArray replace_oprs( | |||||
| SymbolVarArray replace_vars_comp_graph( | SymbolVarArray replace_vars_comp_graph( | ||||
| const SymbolVarArray &dest, ComputingGraph* new_graph); | const SymbolVarArray &dest, ComputingGraph* new_graph); | ||||
| SymbolVarArray find_h2d(const SymbolVarArray& dest); | SymbolVarArray find_h2d(const SymbolVarArray& dest); | ||||
| /*! | /*! | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| #include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
| #include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
| #include "../../core/impl/graph/cg_impl.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| @@ -657,4 +658,309 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { | |||||
| rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
| } | } | ||||
| #if MGB_ENABLE_OPR_MM | |||||
| #include "megbrain/opr/collective_comm.h" | |||||
/* ======================= PackAllReduceScanPass ====================== */

// Identifier of this pass in optimizer logs / diagnostics.
const char* PackAllReduceScanPass::name() const {
    return "pack_allreduce_scan";
}
| void PackAllReduceScanPass::apply(OptState& opt) const { | |||||
| auto comp_graph = opt.graph().comp_graph(); | |||||
| if (comp_graph->options().allreduce_pack_max_size == 0) return; | |||||
| auto cb_scan = [this] (OperatorNodeBase* opr) { | |||||
| if (check_pattern(opr)) { | |||||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
| VarNode* target = comm.input(0)->owner_opr()->input(0); | |||||
| // only pack allreduces of grads of the same target | |||||
| // in case two allreduces depend on each other | |||||
| size_t id = target->id(); | |||||
| uint64_t hash = XXHash().update(&id, sizeof(size_t)).digest(); | |||||
| comm.set_pack_hash(hash); | |||||
| } | |||||
| }; | |||||
| opt.graph().iter(cb_scan); | |||||
| } | |||||
| bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||||
| if (!opr->same_type<opr::CollectiveComm>()) return false; | |||||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
| if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; | |||||
| if (comm.input().size() != 1) return false; | |||||
| auto grad = comm.input(0)->owner_opr(); | |||||
| if (!grad->same_type<opr::VirtualGrad>()) return false; | |||||
| if (grad->input().size() != 2 or grad->output().size() != 1) return false; | |||||
| auto param = grad->input(1)->owner_opr(); | |||||
| if (!param->same_type<opr::SharedDeviceTensor>() and | |||||
| !param->same_type<opr::VolatileSharedDeviceTensor>()) return false; | |||||
| if (param->input().size() != 0) return false; | |||||
| return true; | |||||
| } | |||||
/* ======================= PackAllReduceReplacePass ====================== */

// Identifier of this pass in optimizer logs / diagnostics.
const char* PackAllReduceReplacePass::name() const {
    return "pack_allreduce_replace";
}
// Identity of a packable allreduce group: CollectiveComm oprs may only be
// packed together when all of these fields agree (see GroupInfo::hash, which
// also folds in the pack hash computed by PackAllReduceScanPass).
class PackAllReduceReplacePass::GroupInfo {
    public:
        GroupInfo(int _device, DType _dtype,
                size_t _nr_devices, bool _is_root, int _rank,
                std::shared_ptr<opr::GroupClient> _group_client,
                const std::string& _backend);

        // digest of all fields combined with `extra` (the scan-pass hash)
        uint64_t hash(uint64_t extra) const;

        int device;           // device index of the input var's comp node
        DType dtype;          // dtype of the gradient being reduced
        size_t nr_devices;    // collective group size
        bool is_root;
        int rank;
        std::shared_ptr<opr::GroupClient> group_client;
        std::string backend;
};
// Plain member-wise initializing constructor; no validation is performed.
PackAllReduceReplacePass::GroupInfo::GroupInfo(
        int _device, DType _dtype,
        size_t _nr_devices, bool _is_root, int _rank,
        std::shared_ptr<opr::GroupClient> _group_client,
        const std::string& _backend) :
        device(_device), dtype(_dtype),
        nr_devices(_nr_devices), is_root(_is_root), rank(_rank),
        group_client(_group_client), backend(_backend) {
}
// Fold every identity field plus `extra` (the per-target pack hash from
// PackAllReduceScanPass) into one 64-bit digest; equal digests mean two
// allreduces are candidates for the same pack.
uint64_t PackAllReduceReplacePass::GroupInfo::hash(uint64_t extra) const {
    DTypeEnum ev = dtype.enumv();
    // the group server address distinguishes different group clients
    const std::string& server_addr = group_client->get_addr();
    return XXHash()
            .update(&extra, sizeof(uint64_t))
            .update(&device, sizeof(int))
            .update(&ev, sizeof(DTypeEnum))
            .update(&nr_devices, sizeof(size_t))
            .update(&is_root, sizeof(bool))
            .update(&rank, sizeof(int))
            .update(server_addr.c_str(), server_addr.size())
            .update(backend.c_str(), backend.size())
            .digest();
}
| uint64_t PackAllReduceReplacePass::collect_groups(OperatorNodeBase* opr, | |||||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||||
| ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||||
| // check CollectiveComm oprs that have been marked in PackAllReduceScanPass | |||||
| if (!opr->same_type<opr::CollectiveComm>()) return 0; | |||||
| opr::CollectiveComm& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
| if (comm.pack_hash() == 0) return 0; // pack_hash not set | |||||
| VarNode* var = comm.input(0); | |||||
| auto info = std::make_shared<GroupInfo>( | |||||
| var->comp_node().locator().device, | |||||
| var->dtype(), | |||||
| comm.nr_devices(), | |||||
| comm.is_root(), | |||||
| comm.rank(), | |||||
| comm.group_client(), | |||||
| comm.backend() | |||||
| ); | |||||
| uint64_t hash = info->hash(comm.pack_hash()); | |||||
| if (group_info.find(hash) == group_info.end()) { | |||||
| group_info.emplace(hash, info); | |||||
| } | |||||
| groups[hash].push_back(opr); | |||||
| return hash; | |||||
| } | |||||
| void PackAllReduceReplacePass::divide_packs( | |||||
| const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||||
| size_t max_size) { | |||||
| cg::OprNodeArray pack; | |||||
| size_t sum = 0; | |||||
| for (auto it : groups) { | |||||
| uint64_t hash = it.first; | |||||
| const cg::OprNodeArray& group = it.second; | |||||
| for (size_t i = 0; i < group.size(); i++) { | |||||
| OperatorNodeBase* opr = group[i]; | |||||
| VarNode* var = opr->input(0); | |||||
| const TensorShape* shape = var->owner_graph() | |||||
| ->static_infer_manager().infer_shape_fallible(var); | |||||
| if (shape == nullptr) continue; | |||||
| pack.push_back(opr); | |||||
| sum += var->dtype().size(shape->total_nr_elems()); | |||||
| if (sum >= max_size) { | |||||
| if (pack.size() > 1) packs[hash].push_back(pack); | |||||
| pack.clear(); | |||||
| sum = 0; | |||||
| } | |||||
| } | |||||
| if (pack.size() > 1) packs[hash].push_back(pack); | |||||
| pack.clear(); | |||||
| sum = 0; | |||||
| } | |||||
| } | |||||
// Build the packed replacement subgraph for one pack: flatten and concat all
// gradient inputs, run a single allreduce over the concatenation, then split
// and reshape the result back, recording output-var replacements in
// replace_map. Every opr created here receives the given priority.
void PackAllReduceReplacePass::insert_packed_oprs(
        size_t pack_id,
        const cg::OprNodeArray& pack,
        std::shared_ptr<GroupInfo> info,
        ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) {
    // set priority
    mgb_assert(pack.size() > 0);
    auto graph = pack[0]->owner_graph();
    // while `handler` is alive, every opr inserted into the graph gets
    // its priority forced to `priority` via the OprInserted event
    auto on_opr_inserted = [priority] (const cg::event::OprInserted& event) {
        event.opr->node_prop().attribute().priority = priority;
    };
    auto handler = graph->event().register_receiver<cg::event::OprInserted>(on_opr_inserted);
    // flatten inputs and record shapes and partition
    std::vector<SymbolVar> shapes;
    SymbolVarArray flattens;
    SymbolVarArray partition;
    for (size_t i = 0; i < pack.size(); i++) {
        VarNode* var = pack[i]->input(0);
        auto shape = opr::GetVarShape::make(SymbolVar(var));
        shapes.push_back(shape);
        SymbolVar flatten = SymbolVar(var).flatten();
        flattens.push_back(flatten);
        // element count of each input = product of its shape entries;
        // used later to split the reduced concatenation back apart
        partition.push_back(opr::Reduce::make(shape, {opr::Reduce::Mode::PRODUCT, 0}));
    }
    // concat
    SymbolVar concat = opr::Concat::make(flattens, 0);
    // allreduce over the single concatenated tensor
    std::string key = ssprintf("grad_pack_%zu", pack_id);
    auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph,
            key, info->nr_devices, info->is_root, info->rank,
            info->group_client, param, info->dtype, info->backend)[0];
    // split according to recorded partition
    SymbolVarArray splits = opr::Split::make(allreduce,
            opr::Split::Options::make_partition(0, partition));
    // reshape and insert results into replace_map
    mgb_assert(pack.size() == splits.size());
    for (size_t i = 0; i < pack.size(); i++) {
        VarNode* reshape = splits[i].reshape(shapes[i]).node();
        replace_map[pack[i]->output(0)] = reshape;
    }
}
| void PackAllReduceReplacePass::apply(OptState& opt) const { | |||||
| // get graph options | |||||
| auto comp_graph = opt.graph().comp_graph(); | |||||
| size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024; | |||||
| size_t ignore_first = comp_graph->options().allreduce_pack_ignore_first; | |||||
| if (max_size == 0) return; | |||||
| // get topo order | |||||
| auto& topo_sorter = static_cast<cg::ComputingGraphImpl*>(comp_graph)->topo_sorter(); | |||||
| cg::CompSeqExtraInfo extra_info; | |||||
| VarNodeArray endpoints = to_var_node_array(opt.graph().endpoint_vars()); | |||||
| const cg::OprNodeArray* seq = topo_sorter.get_comp_seq(extra_info, endpoints); | |||||
| topo_sorter.restore_opr_prop(); | |||||
| // collect allreduce groups from topo sequence | |||||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||||
| ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||||
| for (size_t i = 0; i < seq->size(); i++) { | |||||
| if (seq->at(i)->same_type<opr::CollectiveComm>()) { | |||||
| // ignore the first several allreduces | |||||
| if (ignore_first > 0) { | |||||
| --ignore_first; | |||||
| } else { | |||||
| collect_groups(seq->at(i), group_info, groups); | |||||
| } | |||||
| } | |||||
| } | |||||
| // divide groups into packs | |||||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||||
| divide_packs(groups, packs, max_size); | |||||
| // make sure that oprs inserted in this pass (reshape, concat, allreduce, | |||||
| // split, reshape) have higher priority than existing operators | |||||
| int priority = -seq->size() - 100; | |||||
| // insert packed operators and generate replace_map | |||||
| ThinHashMap<VarNode*, VarNode*> replace_map; | |||||
| size_t pack_id = 0; | |||||
| for (auto it : packs) { | |||||
| uint64_t hash = it.first; | |||||
| for (auto pack : it.second) { | |||||
| opt.call_with_opr(pack[0], [&]() { | |||||
| insert_packed_oprs(pack_id, pack, group_info[hash], replace_map, priority); | |||||
| }, OprPropertyFlag::NONE); | |||||
| pack_id += 1; | |||||
| } | |||||
| } | |||||
| // replace vars | |||||
| auto rewriter = opt.graph().make_rewriter(); | |||||
| auto cb_replace = [&](OperatorNodeBase* opr) { | |||||
| for (auto i : opr->input()) { | |||||
| auto iter = replace_map.find(i); | |||||
| if (iter != replace_map.end()) { | |||||
| rewriter.replace_var(i, iter->second, nullptr); | |||||
| } | |||||
| } | |||||
| rewriter.auto_replace_outputs(opr); | |||||
| }; | |||||
| opt.graph().iter(cb_replace); | |||||
| rewriter.apply_inplace(); | |||||
| } | |||||
#else

// Stub implementations used when MGB_ENABLE_OPR_MM is disabled: both passes
// become no-ops so callers may register them unconditionally.

/* ======================= PackAllReduceScanPass ====================== */
const char* PackAllReduceScanPass::name() const {
    return "pack_allreduce_scan";
}

void PackAllReduceScanPass::apply(OptState& opt) const {
}

// NOTE(review): the stub unconditionally returns true, unlike the real
// implementation; apparently never called in this configuration — confirm.
bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
    return true;
}

/* ======================= PackAllReduceReplacePass ====================== */
const char* PackAllReduceReplacePass::name() const {
    return "pack_allreduce_replace";
}

void PackAllReduceReplacePass::apply(OptState& opt) const {}

uint64_t PackAllReduceReplacePass::collect_groups(
        OperatorNodeBase* opr,
        ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info,
        ThinHashMap<uint64_t, cg::OprNodeArray>& groups) {
    return 0;
}

void PackAllReduceReplacePass::divide_packs(
        const ThinHashMap<uint64_t, cg::OprNodeArray>& groups,
        ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs,
        size_t max_size) {
}

void PackAllReduceReplacePass::insert_packed_oprs(
        size_t pack_id,
        const cg::OprNodeArray& pack,
        std::shared_ptr<GroupInfo> info,
        ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) {
}

#endif // MGB_ENABLE_OPR_MM
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -11,6 +11,8 @@ | |||||
| #pragma once | #pragma once | ||||
| #include <vector> | |||||
| #include "megbrain/gopt/framework.h" | #include "megbrain/gopt/framework.h" | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -90,6 +92,45 @@ namespace gopt { | |||||
| void apply(OptState& opt) const override; | void apply(OptState& opt) const override; | ||||
| }; | }; | ||||
//! scan allreduces of param grads
//! Marks each eligible gradient allreduce with a hash of its grad target so
//! PackAllReduceReplacePass can group them; requires
//! ComputingGraph::Options::allreduce_pack_max_size to be non-zero.
class PackAllReduceScanPass final : public Pass {
    public:
        const char* name() const override;
        void apply(OptState& opt) const override;
    private:
        // check pattern param -> grad -> allreduce
        static bool check_pattern(OperatorNodeBase* opr);
};
//! pack allreduces of param grads
//! Replaces groups of compatible allreduces (marked by
//! PackAllReduceScanPass) with a single allreduce over their concatenated
//! inputs; the static helpers are exposed for unit testing.
class PackAllReduceReplacePass final : public Pass {
    public:
        class GroupInfo;

        const char* name() const override;
        void apply(OptState& opt) const override;

        // collect allreduces and divide into groups
        static uint64_t collect_groups(
                OperatorNodeBase* opr,
                ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info,
                ThinHashMap<uint64_t, cg::OprNodeArray>& groups);

        // divide groups into packs, max_size in MB
        static void divide_packs(
                const ThinHashMap<uint64_t, cg::OprNodeArray>& groups,
                ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs,
                size_t max_size);

        // insert packed operators and update replace_map
        static void insert_packed_oprs(
                size_t pack_id,
                const cg::OprNodeArray& pack,
                std::shared_ptr<GroupInfo> info,
                ThinHashMap<VarNode*, VarNode*>& replace_map, int priority);
};
| } // namespace gopt | } // namespace gopt | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -14,6 +14,7 @@ | |||||
| #include "megbrain/gopt/basic_arith.h" | #include "megbrain/gopt/basic_arith.h" | ||||
| #include "megbrain/gopt/misc.h" | #include "megbrain/gopt/misc.h" | ||||
| #include "megbrain/opr/basic_arith_wrapper.h" | #include "megbrain/opr/basic_arith_wrapper.h" | ||||
| #include "megbrain/opr/blas.h" | |||||
| #include "megbrain/opr/cond.h" | #include "megbrain/opr/cond.h" | ||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| @@ -410,4 +411,322 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { | |||||
| check(x_q8_q8, x_q8_fp32_q8_); | check(x_q8_q8, x_q8_fp32_q8_); | ||||
| } | } | ||||
| #if MGB_ENABLE_OPR_MM | |||||
| #include "megbrain/opr/collective_comm.h" | |||||
| #include "../../opr-mm/test/mock_client.h" | |||||
// Allreduces of grads w.r.t. the same target must receive equal pack hashes;
// grads of different targets must not be packed together.
TEST_PASS(PackAllReduceScanPass, Basic) {
    auto graph = ComputingGraph::make();
    graph->options().allreduce_pack_max_size = 5000;  // enable the pass

    auto client = std::make_shared<test::MockGroupClient>();
    auto cn = CompNode::load("gpux");

    auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5});
    auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6});
    auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});

    auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0);
    auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1);
    auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0);
    auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1);

    // two targets (y0, y1), two grads each
    auto grad0 = opr::VirtualGrad::make(y0, x0);
    auto grad1 = opr::VirtualGrad::make(y0, x1);
    auto grad2 = opr::VirtualGrad::make(y1, x0);
    auto grad3 = opr::VirtualGrad::make(y1, x1);

    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(),
            "grad0", 2, 0, 0, client, mode)[0];
    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(),
            "grad1", 2, 0, 0, client, mode)[0];
    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(),
            "grad2", 2, 0, 0, client, mode)[0];
    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(),
            "grad3", 2, 0, 0, client, mode)[0];

    gopt::GraphOptimizer()
            .add_pass<gopt::PackAllReduceScanPass>()
            .apply({{comm0, comm1, comm2, comm3}});

    auto get_hash = [] (const SymbolVar& symvar) {
        cg::OperatorNodeBase* opr = symvar.node()->owner_opr();
        return opr->cast_final_safe<opr::CollectiveComm>().pack_hash();
    };

    uint64_t hash0 = get_hash(comm0);
    uint64_t hash1 = get_hash(comm1);
    uint64_t hash2 = get_hash(comm2);
    uint64_t hash3 = get_hash(comm3);

    // same target -> same hash; different target -> different hash
    ASSERT_EQ(hash0, hash1);
    ASSERT_EQ(hash2, hash3);
    ASSERT_NE(hash0, hash2);
}
// collect_groups must separate allreduces by comp node, dtype, group client
// and extra (scan-pass) hash, and merge those that agree on all of them.
TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 2;

    auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0");
    auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1");

    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;

    // build param -> VirtualGrad -> allreduce, mark it with extra_hash and
    // feed it to collect_groups; returns the resulting group hash
    auto add_opr = [&] (const CompNode& cn, TensorShape shape, const DType& dt,
            std::shared_ptr<test::MockGroupClient> client, uint64_t extra_hash) {
        auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt);
        auto wrt = opr::SharedDeviceTensor::make(*graph, dev0);
        auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt);
        auto target = opr::SharedDeviceTensor::make(*graph, dev1);
        auto grad = opr::VirtualGrad::make(target, wrt);
        auto comm = opr::CollectiveComm::make(
                            {grad}, graph.get(), "key", 2, 0, 0, client,
                            opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
                            .node()->owner_opr();
        comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
        return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info, groups);
    };

    uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1);
    uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1); // same
    uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1); // comp_node
    uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1); // dtype
    uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1); // client
    uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2); // extra_hash

    // only the first two share an identity -> 5 distinct hashes
    ASSERT_EQ(hash0, hash1);

    std::set<uint64_t> s;
    s.insert(hash0);
    s.insert(hash1);
    s.insert(hash2);
    s.insert(hash3);
    s.insert(hash4);
    s.insert(hash5);
    ASSERT_EQ(5, s.size());

    ASSERT_EQ(1, group_info.count(hash0));
    ASSERT_EQ(1, group_info.count(hash1));
    ASSERT_EQ(1, group_info.count(hash2));
    ASSERT_EQ(1, group_info.count(hash3));
    ASSERT_EQ(1, group_info.count(hash4));
    ASSERT_EQ(1, group_info.count(hash5));

    ASSERT_EQ(2, groups[hash0].size());
    ASSERT_EQ(2, groups[hash1].size());
    ASSERT_EQ(1, groups[hash2].size());
    ASSERT_EQ(1, groups[hash3].size());
    ASSERT_EQ(1, groups[hash4].size());
    ASSERT_EQ(1, groups[hash5].size());
}
// divide_packs must cut each group into consecutive packs whose byte size
// reaches the threshold (1000 here), dropping nothing but singleton packs.
TEST_PASS(PackAllReduceReplacePass, DividePacks) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;

    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs;

    // allreduce over a float tensor of `size` bytes, marked as packable
    auto insert_opr = [&] (size_t size) {
        auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)});
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
                "key", 2, 0, 0, client, mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        return opr;
    };

    // total byte size of a pack's inputs
    auto pack_size = [&] (cg::OprNodeArray& pack) {
        size_t sum = 0;
        for (size_t i = 0; i < pack.size(); i++) {
            auto var = pack[i]->input(0);
            sum += var->dtype().size(var->shape().total_nr_elems());
        }
        return sum;
    };

    groups[0].push_back(insert_opr(100));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(400));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(500));  // group0, pack1, size=800
    groups[0].push_back(insert_opr(200));  // group0, pack1, size=800
    groups[0].push_back(insert_opr(100));  // group0, pack1, size=800
    groups[1].push_back(insert_opr(100));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(400));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(300));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(100));  // group1, pack0, size=900

    gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000);

    ASSERT_EQ(2, packs.size());

    ASSERT_EQ(2, packs[0].size());
    ASSERT_EQ(4, packs[0][0].size());
    ASSERT_EQ(1100, pack_size(packs[0][0]));
    ASSERT_EQ(3, packs[0][1].size());
    ASSERT_EQ(800, pack_size(packs[0][1]));

    ASSERT_EQ(1, packs[1].size());
    ASSERT_EQ(4, packs[1][0].size());
    ASSERT_EQ(900, pack_size(packs[1][0]));
}
| TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||||
| auto cn = CompNode::load("gpux"); | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||||
| size_t nr_devices = 2; | |||||
| uint32_t rank = 0; | |||||
| uint32_t root = 0; | |||||
| using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||||
| ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||||
| auto insert_opr = [&] (const TensorShape& shape) { | |||||
| auto dev = std::make_shared<DeviceTensorND>(cn, shape); | |||||
| auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||||
| auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||||
| "key", nr_devices, rank, root, client, mode)[0]; | |||||
| auto opr = symvar.node()->owner_opr(); | |||||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||||
| comm.set_pack_hash(1); | |||||
| gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups); | |||||
| return symvar; | |||||
| }; | |||||
| auto shape_x = TensorShape{100, 200}; | |||||
| auto shape_y = TensorShape{200, 400}; | |||||
| auto x = insert_opr(shape_x); | |||||
| auto y = insert_opr(shape_y); | |||||
| ASSERT_EQ(1, group_info.size()); | |||||
| ASSERT_EQ(1, groups.size()); | |||||
| auto info = group_info.begin()->second; | |||||
| auto pack = groups.begin()->second; | |||||
| size_t pack_id = 0; | |||||
| ThinHashMap<VarNode*, VarNode*> replace_map; | |||||
| gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info, replace_map, -1); | |||||
| auto grad_x = SymbolVar(x.node()->owner_opr()->input(0)); | |||||
| auto grad_y = SymbolVar(y.node()->owner_opr()->input(0)); | |||||
| auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); | |||||
| std::string key = ssprintf("grad_pack_%zu", pack_id); | |||||
| auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), | |||||
| key, nr_devices, rank, root, client, mode)[0]; | |||||
| std::vector<size_t> partition; | |||||
| partition.push_back(shape_x.total_nr_elems()); | |||||
| partition.push_back(shape_y.total_nr_elems()); | |||||
| auto splits = opr::Split::make(allreduce, | |||||
| opr::Split::Options::make_partition(allreduce, 0, partition)); | |||||
| ASSERT_EQ(2, splits.size()); | |||||
| auto dest_x = splits[0].reshape(shape_x); | |||||
| auto dest_y = splits[1].reshape(shape_y); | |||||
| ASSERT_EQ(2, replace_map.size()); | |||||
| ASSERT_TRUE(replace_map.count(x.node()) > 0); | |||||
| ASSERT_EQ(replace_map.at(x.node()), dest_x.node()); | |||||
| ASSERT_TRUE(replace_map.count(y.node()) > 0); | |||||
| ASSERT_EQ(replace_map.at(y.node()), dest_y.node()); | |||||
| } | |||||
| TEST_PASS(PackAllReduceReplacePass, Equivalence) { | |||||
| REQUIRE_GPU(2); | |||||
| auto cns = load_multiple_xpus(2); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto build_graph = [&] (uint32_t rank, std::shared_ptr<ComputingGraph> graph, | |||||
| SymbolVarArray& array) { | |||||
| HostTensorGenerator<> gen; | |||||
| auto cn = cns[rank]; | |||||
| auto host_x = gen({1, 1000}); | |||||
| auto host_y = gen({1000, 1}); | |||||
| auto dev_x = std::make_shared<DeviceTensorND>(cn); | |||||
| auto dev_y = std::make_shared<DeviceTensorND>(cn); | |||||
| dev_x->copy_from(*host_x).sync(); | |||||
| dev_y->copy_from(*host_y).sync(); | |||||
| auto x = opr::SharedDeviceTensor::make(*graph, dev_x); | |||||
| auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y); | |||||
| auto loss = opr::MatrixMul::make(x, y).flatten(); | |||||
| auto grad_x = opr::VirtualGrad::make(loss, x); | |||||
| auto grad_y = opr::VirtualGrad::make(loss, y); | |||||
| using Mode = opr::CollectiveComm::Param::Mode; | |||||
| bool is_root = (rank == 0); | |||||
| auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), | |||||
| "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||||
| auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), | |||||
| "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||||
| graph->options().allreduce_pack_max_size = 5000; | |||||
| graph->options().allreduce_pack_ignore_first = 0; | |||||
| auto dest_vars = gopt::GraphOptimizer{} | |||||
| .add_pass<gopt::PackAllReduceScanPass>() | |||||
| .add_pass<gopt::PackAllReduceReplacePass>() | |||||
| .apply({{reduced_x, reduced_y}}).endpoint_vars(); | |||||
| array.emplace_back(reduced_x); | |||||
| array.emplace_back(reduced_y); | |||||
| array.emplace_back(dest_vars[0]); | |||||
| array.emplace_back(dest_vars[1]); | |||||
| }; | |||||
| auto run = [&] (uint32_t rank) { | |||||
| auto graph = ComputingGraph::make(); | |||||
| SymbolVarArray array; | |||||
| build_graph(rank, graph, array); | |||||
| HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1; | |||||
| graph->options().allreduce_pack_max_size = 0; | |||||
| auto func = graph->compile({make_callback_copy(array[0], host_reduced_x), | |||||
| make_callback_copy(array[1], host_reduced_y), | |||||
| make_callback_copy(array[2], host_dest_0), | |||||
| make_callback_copy(array[3], host_dest_1)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0); | |||||
| MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1); | |||||
| }; | |||||
| std::thread t0(run, 0); | |||||
| std::thread t1(run, 1); | |||||
| t0.join(); | |||||
| t1.join(); | |||||
| } | |||||
| #endif // MGB_ENABLE_OPR_MM | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -461,16 +461,7 @@ void CollectiveComm::opr_register() { | |||||
| m_rank = reg_info.rank; | m_rank = reg_info.rank; | ||||
| m_root = reg_info.root_rank; | m_root = reg_info.root_rank; | ||||
| MegRayCommunicatorBuilder* builder; | |||||
| { | |||||
| static std::mutex user_data_mtx; | |||||
| std::unique_lock<std::mutex> lk(user_data_mtx); | |||||
| builder = owner_graph()->options().user_data | |||||
| .get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||||
| } | |||||
| m_megray_comm = builder->get_megray_comm( | |||||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||||
| reg_info.hash, m_key, m_nr_devices, m_rank, | reg_info.hash, m_key, m_nr_devices, m_rank, | ||||
| get_megray_backend(m_backend), m_group_client); | get_megray_backend(m_backend), m_group_client); | ||||
| @@ -736,13 +727,15 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
| auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | ||||
| return opr::CollectiveComm::make( | |||||
| auto new_opr = CollectiveComm::make( | |||||
| to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | ||||
| opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | ||||
| opr.group_client(), opr.dev_buffers(), opr.param(), | opr.group_client(), opr.dev_buffers(), opr.param(), | ||||
| opr.dtype(), opr.backend(), config)[0] | opr.dtype(), opr.backend(), config)[0] | ||||
| .node() | .node() | ||||
| ->owner_opr(); | ->owner_opr(); | ||||
| new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | |||||
| return new_opr; | |||||
| } | } | ||||
| MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); | MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); | ||||
| @@ -54,13 +54,7 @@ void RemoteSend::scn_do_execute() { | |||||
| auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | ||||
| comp_node.get_uid()); | comp_node.get_uid()); | ||||
| auto megray_comm_builder = | |||||
| owner_graph() | |||||
| ->options() | |||||
| .user_data | |||||
| .get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||||
| m_megray_comm = megray_comm_builder->get_megray_comm( | |||||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||||
| reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | ||||
| m_init = true; | m_init = true; | ||||
| } | } | ||||
| @@ -158,13 +152,7 @@ void RemoteRecv::scn_do_execute() { | |||||
| m_peer.key, 2, false, 1, | m_peer.key, 2, false, 1, | ||||
| comp_node.get_uid()); | comp_node.get_uid()); | ||||
| auto megray_comm_builder = | |||||
| owner_graph() | |||||
| ->options() | |||||
| .user_data | |||||
| .get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||||
| m_megray_comm = megray_comm_builder->get_megray_comm( | |||||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||||
| reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | ||||
| m_init = true; | m_init = true; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace opr; | using namespace opr; | ||||
| bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||||
| std::unique_lock<std::mutex> lk(m_mtx); | |||||
| bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||||
| std::unique_lock<std::mutex> lk(m_map_mtx); | |||||
| auto it = m_megray_comms.find(hash); | auto it = m_megray_comms.find(hash); | ||||
| if (it != m_megray_comms.end()) { | if (it != m_megray_comms.end()) { | ||||
| comm = it->second; | comm = it->second; | ||||
| @@ -24,27 +24,37 @@ bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Comm | |||||
| return false; | return false; | ||||
| } | } | ||||
| void MegRayCommunicatorBuilder::emplace(uint64_t hash, | |||||
| void MegRayCommBuilder::emplace(uint64_t hash, | |||||
| std::shared_ptr<MegRay::Communicator> comm) { | std::shared_ptr<MegRay::Communicator> comm) { | ||||
| std::unique_lock<std::mutex> lk(m_mtx); | |||||
| std::unique_lock<std::mutex> lk(m_map_mtx); | |||||
| m_megray_comms.emplace(hash, comm); | m_megray_comms.emplace(hash, comm); | ||||
| } | } | ||||
| std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm( | |||||
| std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||||
| uint64_t hash, std::string key, uint32_t size, uint32_t rank, | uint64_t hash, std::string key, uint32_t size, uint32_t rank, | ||||
| MegRay::Backend backend, | MegRay::Backend backend, | ||||
| std::shared_ptr<mgb::opr::GroupClient> group_client) { | std::shared_ptr<mgb::opr::GroupClient> group_client) { | ||||
| { | |||||
| // singleton pattern | |||||
| std::unique_lock<std::mutex> lk(sm_instance_mtx); | |||||
| if (sm_instance == nullptr) { | |||||
| sm_instance = new MegRayCommBuilder(); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<MegRay::Communicator> comm; | std::shared_ptr<MegRay::Communicator> comm; | ||||
| if (!find(hash, comm)) { | |||||
| if (!sm_instance->find(hash, comm)) { | |||||
| comm = MegRay::get_communicator(size, rank, backend); | comm = MegRay::get_communicator(size, rank, backend); | ||||
| auto uid = comm->get_uid(); | auto uid = comm->get_uid(); | ||||
| auto uids = group_client->gather_uid(uid, key, size, rank); | auto uids = group_client->gather_uid(uid, key, size, rank); | ||||
| mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | ||||
| emplace(hash, comm); | |||||
| sm_instance->emplace(hash, comm); | |||||
| } | } | ||||
| return comm; | return comm; | ||||
| } | } | ||||
| MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); | |||||
| MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | |||||
| std::mutex MegRayCommBuilder::sm_instance_mtx; | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -81,6 +81,10 @@ public: | |||||
| return m_group_client; | return m_group_client; | ||||
| } | } | ||||
| void set_pack_hash(uint64_t hash) { m_pack_hash = hash; } | |||||
| uint64_t pack_hash() const { return m_pack_hash; } | |||||
| std::shared_ptr<MegRay::Context> megray_ctx() const { | std::shared_ptr<MegRay::Context> megray_ctx() const { | ||||
| return m_megray_ctx; | return m_megray_ctx; | ||||
| } | } | ||||
| @@ -123,6 +127,9 @@ private: | |||||
| // whose shape infer should be disabled *during* static infer phase. | // whose shape infer should be disabled *during* static infer phase. | ||||
| bool m_enable_shape_infer = false; | bool m_enable_shape_infer = false; | ||||
| //! set in PackAllReduceScanPass and used in PackAllReduceReplacePass | |||||
| uint64_t m_pack_hash = 0; | |||||
| std::shared_ptr<MegRay::Context> m_megray_ctx; | std::shared_ptr<MegRay::Context> m_megray_ctx; | ||||
| std::shared_ptr<MegRay::Communicator> m_megray_comm; | std::shared_ptr<MegRay::Communicator> m_megray_comm; | ||||
| bool m_init = false; | bool m_init = false; | ||||
| @@ -126,6 +126,8 @@ class GroupClient { | |||||
| virtual ~GroupClient() = default; | virtual ~GroupClient() = default; | ||||
| public: | public: | ||||
| virtual const std::string& get_addr() const = 0; | |||||
| virtual GroupManager::RegisterInfo opr_register(const std::string& key, | virtual GroupManager::RegisterInfo opr_register(const std::string& key, | ||||
| size_t nr_devices, | size_t nr_devices, | ||||
| bool is_root, int rank, | bool is_root, int rank, | ||||
| @@ -23,18 +23,19 @@ namespace opr { | |||||
| /*! | /*! | ||||
| * gather MegRay unique ids and build communicator, use hash for deduplication | * gather MegRay unique ids and build communicator, use hash for deduplication | ||||
| */ | */ | ||||
| class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData { | |||||
| MGB_TYPEINFO_OBJ_DECL; | |||||
| class MegRayCommBuilder { | |||||
| private: | private: | ||||
| bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | ||||
| void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | ||||
| std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | ||||
| std::mutex m_mtx; | |||||
| std::mutex m_map_mtx; | |||||
| static MegRayCommBuilder* sm_instance; | |||||
| static std::mutex sm_instance_mtx; | |||||
| public: | public: | ||||
| std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||||
| static std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||||
| uint64_t hash, std::string key, uint32_t size, uint32_t rank, | uint64_t hash, std::string key, uint32_t size, uint32_t rank, | ||||
| MegRay::Backend backend, | MegRay::Backend backend, | ||||
| std::shared_ptr<mgb::opr::GroupClient> group_client); | std::shared_ptr<mgb::opr::GroupClient> group_client); | ||||
| @@ -47,7 +47,7 @@ public: | |||||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override; | uint32_t group_barrier(uint32_t size, uint32_t rank) override; | ||||
| const std::string& get_addr() const { | |||||
| const std::string& get_addr() const override { | |||||
| return m_addr; | return m_addr; | ||||
| } | } | ||||
| @@ -17,11 +17,10 @@ | |||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| #include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
| #include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
| #include "mock_client.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| namespace { | |||||
| using Mode = opr::CollectiveComm::Param::Mode; | using Mode = opr::CollectiveComm::Param::Mode; | ||||
| SymbolVar make_all_reduce_output(const Mode mode, | SymbolVar make_all_reduce_output(const Mode mode, | ||||
| @@ -41,41 +40,6 @@ SymbolVarArray make_reduce_scatter_sum_output(const SymbolVarArray& inputs) { | |||||
| rdc, opr::Split::Options::make_average(0, inputs.size())); | rdc, opr::Split::Options::make_average(0, inputs.size())); | ||||
| } | } | ||||
| class MockGroupClient final : public opr::GroupClient { | |||||
| public: | |||||
| ~MockGroupClient() override = default; | |||||
| opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||||
| size_t nr_devices, | |||||
| bool is_root, int rank, | |||||
| uintptr_t stream) { | |||||
| return m_mgr.opr_register(key, nr_devices, is_root, rank, stream); | |||||
| } | |||||
| std::vector<std::string> gather_uid(const std::string& uid, | |||||
| const std::string& key, uint32_t size, uint32_t rank) { | |||||
| return m_mgr.gather_uid(uid, key, size, rank); | |||||
| } | |||||
| void set_output_shape(const std::string& key, | |||||
| const TensorShape& shape) override { | |||||
| m_mgr.set_output_shape(key, shape); | |||||
| } | |||||
| TensorShape get_output_shape(const std::string& key) override { | |||||
| return m_mgr.get_output_shape(key); | |||||
| } | |||||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||||
| return m_mgr.group_barrier(size, rank); | |||||
| } | |||||
| private: | |||||
| opr::GroupManager m_mgr; | |||||
| }; | |||||
| } // namespace | |||||
| TEST(TestOprCollectiveComm, AllReduce) { | TEST(TestOprCollectiveComm, AllReduce) { | ||||
| REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
| @@ -88,7 +52,7 @@ TEST(TestOprCollectiveComm, AllReduce) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
| @@ -126,7 +90,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { | auto run_0 = [&]() { | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -187,7 +151,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { | |||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -268,7 +232,7 @@ TEST(TestOprCollectiveComm, AllGather) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
| @@ -300,7 +264,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -356,7 +320,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { | |||||
| HostTensorND host_out_grad0, host_out_grad1; | HostTensorND host_out_grad0, host_out_grad1; | ||||
| HostTensorND host_out_grad0_expect, host_out_grad1_expect; | HostTensorND host_out_grad0_expect, host_out_grad1_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -438,7 +402,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
| @@ -471,7 +435,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||||
| auto host_x1 = gen({8}); | auto host_x1 = gen({8}); | ||||
| HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -528,7 +492,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||||
| HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | ||||
| HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -610,7 +574,7 @@ TEST(TestOprCollectiveComm, ReduceSum) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
| @@ -641,7 +605,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||||
| auto host_x1 = gen({28, 28}); | auto host_x1 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y_expect; | HostTensorND host_y0, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -694,7 +658,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||||
| HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; | HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -764,7 +728,7 @@ TEST(TestOprCollectiveComm, Broadcast) { | |||||
| auto host_x0 = gen({28, 28}); | auto host_x0 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1, host_y_expect; | HostTensorND host_y0, host_y1, host_y_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | ||||
| @@ -794,7 +758,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||||
| auto host_x0 = gen({28, 28}); | auto host_x0 = gen({28, 28}); | ||||
| HostTensorND host_y0, host_y1; | HostTensorND host_y0, host_y1; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -840,7 +804,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||||
| HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; | HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto run_0 = [&]() { // rank 0 | auto run_0 = [&]() { // rank 0 | ||||
| auto graph0 = ComputingGraph::make(); | auto graph0 = ComputingGraph::make(); | ||||
| @@ -14,51 +14,14 @@ | |||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| #include "megbrain/system.h" | #include "megbrain/system.h" | ||||
| #include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
| #include "mock_client.h" | |||||
| #include <thread> | #include <thread> | ||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace opr; | |||||
| namespace { | |||||
| class MockGroupClient final : public opr::GroupClient { | |||||
| public: | |||||
| ~MockGroupClient() override = default; | |||||
| opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||||
| size_t nr_devices, | |||||
| bool is_root, int rank, | |||||
| uint64_t comp_node_hash) { | |||||
| return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||||
| } | |||||
| std::vector<std::string> gather_uid(const std::string& uid, | |||||
| const std::string& key, uint32_t size, uint32_t rank) { | |||||
| return m_mgr.gather_uid(uid, key, size, rank); | |||||
| } | |||||
| void set_output_shape(const std::string& key, | |||||
| const TensorShape& shape) override { | |||||
| m_mgr.set_output_shape(key, shape); | |||||
| } | |||||
| TensorShape get_output_shape(const std::string& key) override { | |||||
| return m_mgr.get_output_shape(key); | |||||
| } | |||||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||||
| return m_mgr.group_barrier(size, rank); | |||||
| } | |||||
| private: | |||||
| opr::GroupManager m_mgr; | |||||
| }; | |||||
| const auto send_tag = RemoteIOBase::Type::SEND; | |||||
| const auto recv_tag = RemoteIOBase::Type::RECV; | |||||
| } // anonymous namespace | |||||
| const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||||
| const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||||
| TEST(TestOprIORemote, Identity) { | TEST(TestOprIORemote, Identity) { | ||||
| REQUIRE_GPU(2); | REQUIRE_GPU(2); | ||||
| @@ -69,7 +32,7 @@ TEST(TestOprIORemote, Identity) { | |||||
| auto host_x = gen({28, 28}); | auto host_x = gen({28, 28}); | ||||
| HostTensorND host_y; | HostTensorND host_y; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | ||||
| @@ -90,7 +53,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto host_x = gen({2, 3}, cns[1]); | auto host_x = gen({2, 3}, cns[1]); | ||||
| HostTensorND host_x_get; | HostTensorND host_x_get; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto sender = [&]() { | auto sender = [&]() { | ||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| @@ -123,7 +86,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto host_x = gen({2, 3}, cns[0]); | auto host_x = gen({2, 3}, cns[0]); | ||||
| HostTensorND host_x_get; | HostTensorND host_x_get; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto sender = [&]() { | auto sender = [&]() { | ||||
| sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
| @@ -157,7 +120,7 @@ TEST(TestOprIORemote, APlusB) { | |||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]); | auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]); | ||||
| HostTensorND host_z; | HostTensorND host_z; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto sender = [&]() { | auto sender = [&]() { | ||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| @@ -208,7 +171,7 @@ TEST(TestOprIORemote, SendGrad) { | |||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto host_x = gen({2, 3}, cns[0]); | auto host_x = gen({2, 3}, cns[0]); | ||||
| HostTensorND host_gx, host_loss; | HostTensorND host_gx, host_loss; | ||||
| auto client = std::make_shared<MockGroupClient>(); | |||||
| auto client = std::make_shared<test::MockGroupClient>(); | |||||
| auto sender = [&]() { | auto sender = [&]() { | ||||
| sys::set_thread_name("sender"); | sys::set_thread_name("sender"); | ||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * \file src/opr-mm/test/mock_client.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/opr/group_manager.h" | |||||
| namespace mgb { | |||||
| namespace test { | |||||
| class MockGroupClient final : public opr::GroupClient { | |||||
| public: | |||||
| using RegisterInfo = opr::GroupManager::RegisterInfo; | |||||
| MockGroupClient(const std::string& server_addr = "mock_addr") : | |||||
| m_addr(server_addr) { | |||||
| } | |||||
| ~MockGroupClient() override = default; | |||||
| const std::string& get_addr() const { | |||||
| return m_addr; | |||||
| } | |||||
| RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||||
| bool is_root, int rank, uint64_t comp_node_hash) { | |||||
| return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||||
| } | |||||
| std::vector<std::string> gather_uid(const std::string& uid, | |||||
| const std::string& key, uint32_t size, uint32_t rank) { | |||||
| return m_mgr.gather_uid(uid, key, size, rank); | |||||
| } | |||||
| void set_output_shape(const std::string& key, | |||||
| const TensorShape& shape) override { | |||||
| m_mgr.set_output_shape(key, shape); | |||||
| } | |||||
| TensorShape get_output_shape(const std::string& key) override { | |||||
| return m_mgr.get_output_shape(key); | |||||
| } | |||||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||||
| return m_mgr.group_barrier(size, rank); | |||||
| } | |||||
| private: | |||||
| const std::string m_addr; | |||||
| opr::GroupManager m_mgr; | |||||
| }; | |||||
| } // namespace test | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||