| @@ -260,3 +260,15 @@ def replace_oprs(dst, oprmap): | |||
| repl_dst_vec.push_back(j) | |||
| return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | |||
| def set_priority_to_id(dest_vars): | |||
| """For all oprs in the subgraph constructed by dest_vars | |||
| set its priority to id if its original priority is zero | |||
| :param dest_vars: target vars representing the graph | |||
| """ | |||
| dest_vec = _mgb._VectorSymbolVar() | |||
| for i in dest_vars: | |||
| assert isinstance(i, _mgb.SymbolVar) | |||
| dest_vec.push_back(i) | |||
| _mgb._set_priority_to_id(dest_vec) | |||
| @@ -84,6 +84,8 @@ class trace: | |||
| :param log_level: Log level. | |||
| :param sublinear_memory_config: Configuration for sublinear memory optimization. | |||
| If not None, it enables sublinear memory optimization with given setting. | |||
| :param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. | |||
| If not None, multiple gradients will be packed and synchronized together | |||
| :param profiling: Whether to profile compiled trace. Default: False | |||
| """ | |||
| @@ -107,6 +109,7 @@ class trace: | |||
| opt_level: int = None, | |||
| log_level: int = None, | |||
| sublinear_memory_config: SublinearMemoryConfig = None, | |||
| allreduce_pack_max_size: int = None, | |||
| profiling: bool = False | |||
| ): | |||
| self.__wrapped__ = func | |||
| @@ -114,6 +117,7 @@ class trace: | |||
| self._graph_opt_level = opt_level | |||
| self._log_level = log_level | |||
| self._sublinear_memory_config = sublinear_memory_config | |||
| self._allreduce_pack_max_size = allreduce_pack_max_size | |||
| self._status = self._UNSTARTED | |||
| self._args = None | |||
| self._kwargs = None | |||
| @@ -313,6 +317,9 @@ class trace: | |||
| "sublinear_mem_cofig.num_worker", | |||
| self._sublinear_memory_config.num_worker, | |||
| ) | |||
| # pack allreduce | |||
| if self._allreduce_pack_max_size is not None: | |||
| cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) | |||
| # profile | |||
| if self._profiling: | |||
| self._profiler = CompGraphProfiler(cg) | |||
| @@ -391,6 +398,7 @@ class trace: | |||
| outputs = [outputs] | |||
| # _run_wrapped has checked validity of outputs | |||
| self._sym_outputs = tuple(i._symvar for i in outputs) | |||
| mgb.comp_graph_tools.set_priority_to_id(self._outspec) | |||
| self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | |||
| def trace(self, *args: Tensor, **kwargs): | |||
| @@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): | |||
| :param loss: The obtained loss tensor | |||
| """ | |||
| rst = [] | |||
| priority = 0 | |||
| params = [] | |||
| for group in self.param_groups: | |||
| for param in group["params"]: | |||
| @@ -180,14 +179,14 @@ class Optimizer(metaclass=ABCMeta): | |||
| for param, grad in zip(params, grads): | |||
| if is_distributed(): | |||
| priority += 1 | |||
| with opr_priority_scope(cg, -priority): | |||
| # all_reduce_mean | |||
| with opr_priority_scope(cg, -(2 ** 30)): | |||
| # always run all_reduce_mean first except add_update | |||
| grad = ( | |||
| all_reduce_sum(grad, "grad_" + str(get_group_id())) | |||
| / get_world_size() | |||
| ) | |||
| with opr_priority_scope(cg, (1 << 30) - priority): | |||
| with opr_priority_scope(cg, -(2 ** 31)): | |||
| # always run add_update first | |||
| grad_update = add_update(param.grad, grad) | |||
| else: | |||
| grad_update = add_update(param.grad, grad) | |||
| @@ -66,6 +66,8 @@ bool _config::set_comp_graph_option( | |||
| SET_CG_OPTION(graph_opt.jit); | |||
| SET_CG_OPTION(graph_opt.tensorrt); | |||
| SET_CG_OPTION(graph_opt_level); | |||
| SET_CG_OPTION(allreduce_pack_max_size); | |||
| SET_CG_OPTION(allreduce_pack_ignore_first); | |||
| SET_CG_OPTION(var_sanity_check_first_run); | |||
| SET_CG_OPTION(no_profiling_on_shape_change); | |||
| SET_CG_OPTION(allocate_static_mem_after_graph_compile); | |||
| @@ -1,3 +1,7 @@ | |||
| %{ | |||
| #include "megbrain/gopt/framework.h" | |||
| %} | |||
| %inline { | |||
| SymbolVarArray _get_owner_opr_inputs(SymbolVar var) { | |||
| @@ -35,5 +39,17 @@ | |||
| } | |||
| return mgb::cg::replace_oprs(vars, oprmap); | |||
| } | |||
| void _set_priority_to_id(const SymbolVarArray& dest_vars) { | |||
| auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | |||
| if (opr->node_prop().attribute().priority == 0) { | |||
| opr->node_prop().attribute().priority = opr->id(); | |||
| } | |||
| }; | |||
| mgb::cg::DepOprIter dep_iter{on_opr}; | |||
| for (const SymbolVar& var : dest_vars) { | |||
| dep_iter.add(var); | |||
| } | |||
| } | |||
| } | |||
| // vim: ft=swig foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -441,12 +441,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||
| optimizer.verbosity(options().log_level); | |||
| optimizer.enable_check_result(options().graph_opt_level < 0); | |||
| if (sopr_stat.has_virtual_grad) { | |||
| if (need_opt) | |||
| if (need_opt) { | |||
| #if MGB_ENABLE_OPR_MM | |||
| optimizer.add_pass<gopt::PackAllReduceScanPass>(); | |||
| #endif | |||
| optimizer.add_preset_passes(false, nullptr, &options()); | |||
| } | |||
| optimizer.add_pass<gopt::ExpandVirtualGradPass>(); | |||
| } | |||
| if (need_opt) | |||
| if (need_opt) { | |||
| optimizer.add_preset_passes(true, nullptr, &options()); | |||
| #if MGB_ENABLE_OPR_MM | |||
| if (sopr_stat.has_virtual_grad) { | |||
| optimizer.add_pass<gopt::PackAllReduceReplacePass>(); | |||
| } | |||
| #endif | |||
| } | |||
| optimizer.apply_inplace(dest_vars); | |||
| } | |||
| #endif | |||
| @@ -327,6 +327,18 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, | |||
| */ | |||
| int16_t graph_opt_level = 2; | |||
| /*! | |||
| * max size of allreduce packs in MB | |||
| * set this option to zero to disable PackAllReducePass | |||
| */ | |||
| int16_t allreduce_pack_max_size = 0; | |||
| /*! | |||
| * do not pack the first n allreduces | |||
| * PackAllReducePass disabled if allreduce_pack_max_size is zero | |||
| */ | |||
| int16_t allreduce_pack_ignore_first = 2; | |||
| /*! | |||
| * set logging level, larger number means more verbose | |||
| * 0: no log info | |||
| @@ -183,7 +183,6 @@ SymbolVarArray replace_oprs( | |||
| SymbolVarArray replace_vars_comp_graph( | |||
| const SymbolVarArray &dest, ComputingGraph* new_graph); | |||
| SymbolVarArray find_h2d(const SymbolVarArray& dest); | |||
| /*! | |||
| @@ -17,6 +17,7 @@ | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/serialization/serializer.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "../../core/impl/graph/cg_impl.h" | |||
| using namespace mgb; | |||
| using namespace gopt; | |||
| @@ -657,4 +658,309 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { | |||
| rewriter.apply_inplace(); | |||
| } | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include "megbrain/opr/collective_comm.h" | |||
| /* ======================= PackAllReduceScanPass ====================== */ | |||
| const char* PackAllReduceScanPass::name() const { | |||
| return "pack_allreduce_scan"; | |||
| } | |||
| void PackAllReduceScanPass::apply(OptState& opt) const { | |||
| auto comp_graph = opt.graph().comp_graph(); | |||
| if (comp_graph->options().allreduce_pack_max_size == 0) return; | |||
| auto cb_scan = [this] (OperatorNodeBase* opr) { | |||
| if (check_pattern(opr)) { | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| VarNode* target = comm.input(0)->owner_opr()->input(0); | |||
| // only pack allreduces of grads of the same target | |||
| // in case two allreduces depend on each other | |||
| size_t id = target->id(); | |||
| uint64_t hash = XXHash().update(&id, sizeof(size_t)).digest(); | |||
| comm.set_pack_hash(hash); | |||
| } | |||
| }; | |||
| opt.graph().iter(cb_scan); | |||
| } | |||
| bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||
| if (!opr->same_type<opr::CollectiveComm>()) return false; | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; | |||
| if (comm.input().size() != 1) return false; | |||
| auto grad = comm.input(0)->owner_opr(); | |||
| if (!grad->same_type<opr::VirtualGrad>()) return false; | |||
| if (grad->input().size() != 2 or grad->output().size() != 1) return false; | |||
| auto param = grad->input(1)->owner_opr(); | |||
| if (!param->same_type<opr::SharedDeviceTensor>() and | |||
| !param->same_type<opr::VolatileSharedDeviceTensor>()) return false; | |||
| if (param->input().size() != 0) return false; | |||
| return true; | |||
| } | |||
| /* ======================= PackAllReduceReplacePass ====================== */ | |||
| const char* PackAllReduceReplacePass::name() const { | |||
| return "pack_allreduce_replace"; | |||
| } | |||
| class PackAllReduceReplacePass::GroupInfo { | |||
| public: | |||
| GroupInfo(int _device, DType _dtype, | |||
| size_t _nr_devices, bool _is_root, int _rank, | |||
| std::shared_ptr<opr::GroupClient> _group_client, | |||
| const std::string& _backend); | |||
| uint64_t hash(uint64_t extra) const; | |||
| int device; | |||
| DType dtype; | |||
| size_t nr_devices; | |||
| bool is_root; | |||
| int rank; | |||
| std::shared_ptr<opr::GroupClient> group_client; | |||
| std::string backend; | |||
| }; | |||
| PackAllReduceReplacePass::GroupInfo::GroupInfo( | |||
| int _device, DType _dtype, | |||
| size_t _nr_devices, bool _is_root, int _rank, | |||
| std::shared_ptr<opr::GroupClient> _group_client, | |||
| const std::string& _backend) : | |||
| device(_device), dtype(_dtype), | |||
| nr_devices(_nr_devices), is_root(_is_root), rank(_rank), | |||
| group_client(_group_client), backend(_backend) { | |||
| } | |||
| uint64_t PackAllReduceReplacePass::GroupInfo::hash(uint64_t extra) const { | |||
| DTypeEnum ev = dtype.enumv(); | |||
| const std::string& server_addr = group_client->get_addr(); | |||
| return XXHash() | |||
| .update(&extra, sizeof(uint64_t)) | |||
| .update(&device, sizeof(int)) | |||
| .update(&ev, sizeof(DTypeEnum)) | |||
| .update(&nr_devices, sizeof(size_t)) | |||
| .update(&is_root, sizeof(bool)) | |||
| .update(&rank, sizeof(int)) | |||
| .update(server_addr.c_str(), server_addr.size()) | |||
| .update(backend.c_str(), backend.size()) | |||
| .digest(); | |||
| } | |||
| uint64_t PackAllReduceReplacePass::collect_groups(OperatorNodeBase* opr, | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||
| ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||
| // check CollectiveComm oprs that have been marked in PackAllReduceScanPass | |||
| if (!opr->same_type<opr::CollectiveComm>()) return 0; | |||
| opr::CollectiveComm& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| if (comm.pack_hash() == 0) return 0; // pack_hash not set | |||
| VarNode* var = comm.input(0); | |||
| auto info = std::make_shared<GroupInfo>( | |||
| var->comp_node().locator().device, | |||
| var->dtype(), | |||
| comm.nr_devices(), | |||
| comm.is_root(), | |||
| comm.rank(), | |||
| comm.group_client(), | |||
| comm.backend() | |||
| ); | |||
| uint64_t hash = info->hash(comm.pack_hash()); | |||
| if (group_info.find(hash) == group_info.end()) { | |||
| group_info.emplace(hash, info); | |||
| } | |||
| groups[hash].push_back(opr); | |||
| return hash; | |||
| } | |||
| void PackAllReduceReplacePass::divide_packs( | |||
| const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||
| size_t max_size) { | |||
| cg::OprNodeArray pack; | |||
| size_t sum = 0; | |||
| for (auto it : groups) { | |||
| uint64_t hash = it.first; | |||
| const cg::OprNodeArray& group = it.second; | |||
| for (size_t i = 0; i < group.size(); i++) { | |||
| OperatorNodeBase* opr = group[i]; | |||
| VarNode* var = opr->input(0); | |||
| const TensorShape* shape = var->owner_graph() | |||
| ->static_infer_manager().infer_shape_fallible(var); | |||
| if (shape == nullptr) continue; | |||
| pack.push_back(opr); | |||
| sum += var->dtype().size(shape->total_nr_elems()); | |||
| if (sum >= max_size) { | |||
| if (pack.size() > 1) packs[hash].push_back(pack); | |||
| pack.clear(); | |||
| sum = 0; | |||
| } | |||
| } | |||
| if (pack.size() > 1) packs[hash].push_back(pack); | |||
| pack.clear(); | |||
| sum = 0; | |||
| } | |||
| } | |||
| void PackAllReduceReplacePass::insert_packed_oprs( | |||
| size_t pack_id, | |||
| const cg::OprNodeArray& pack, | |||
| std::shared_ptr<GroupInfo> info, | |||
| ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) { | |||
| // set priority | |||
| mgb_assert(pack.size() > 0); | |||
| auto graph = pack[0]->owner_graph(); | |||
| auto on_opr_inserted = [priority] (const cg::event::OprInserted& event) { | |||
| event.opr->node_prop().attribute().priority = priority; | |||
| }; | |||
| auto handler = graph->event().register_receiver<cg::event::OprInserted>(on_opr_inserted); | |||
| // flatten inputs and record shapes and partition | |||
| std::vector<SymbolVar> shapes; | |||
| SymbolVarArray flattens; | |||
| SymbolVarArray partition; | |||
| for (size_t i = 0; i < pack.size(); i++) { | |||
| VarNode* var = pack[i]->input(0); | |||
| auto shape = opr::GetVarShape::make(SymbolVar(var)); | |||
| shapes.push_back(shape); | |||
| SymbolVar flatten = SymbolVar(var).flatten(); | |||
| flattens.push_back(flatten); | |||
| partition.push_back(opr::Reduce::make(shape, {opr::Reduce::Mode::PRODUCT, 0})); | |||
| } | |||
| // concat | |||
| SymbolVar concat = opr::Concat::make(flattens, 0); | |||
| // allreduce | |||
| std::string key = ssprintf("grad_pack_%zu", pack_id); | |||
| auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
| SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph, | |||
| key, info->nr_devices, info->is_root, info->rank, | |||
| info->group_client, param, info->dtype, info->backend)[0]; | |||
| // split according to recorded partition | |||
| SymbolVarArray splits = opr::Split::make(allreduce, | |||
| opr::Split::Options::make_partition(0, partition)); | |||
| // reshape and insert results into replace_map | |||
| mgb_assert(pack.size() == splits.size()); | |||
| for (size_t i = 0; i < pack.size(); i++) { | |||
| VarNode* reshape = splits[i].reshape(shapes[i]).node(); | |||
| replace_map[pack[i]->output(0)] = reshape; | |||
| } | |||
| } | |||
| void PackAllReduceReplacePass::apply(OptState& opt) const { | |||
| // get graph options | |||
| auto comp_graph = opt.graph().comp_graph(); | |||
| size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024; | |||
| size_t ignore_first = comp_graph->options().allreduce_pack_ignore_first; | |||
| if (max_size == 0) return; | |||
| // get topo order | |||
| auto& topo_sorter = static_cast<cg::ComputingGraphImpl*>(comp_graph)->topo_sorter(); | |||
| cg::CompSeqExtraInfo extra_info; | |||
| VarNodeArray endpoints = to_var_node_array(opt.graph().endpoint_vars()); | |||
| const cg::OprNodeArray* seq = topo_sorter.get_comp_seq(extra_info, endpoints); | |||
| topo_sorter.restore_opr_prop(); | |||
| // collect allreduce groups from topo sequence | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
| ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
| for (size_t i = 0; i < seq->size(); i++) { | |||
| if (seq->at(i)->same_type<opr::CollectiveComm>()) { | |||
| // ignore the first several allreduces | |||
| if (ignore_first > 0) { | |||
| --ignore_first; | |||
| } else { | |||
| collect_groups(seq->at(i), group_info, groups); | |||
| } | |||
| } | |||
| } | |||
| // divide groups into packs | |||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||
| divide_packs(groups, packs, max_size); | |||
| // make sure that oprs inserted in this pass (reshape, concat, allreduce, | |||
| // split, reshape) have higher priority than existing operators | |||
| int priority = -seq->size() - 100; | |||
| // insert packed operators and generate replace_map | |||
| ThinHashMap<VarNode*, VarNode*> replace_map; | |||
| size_t pack_id = 0; | |||
| for (auto it : packs) { | |||
| uint64_t hash = it.first; | |||
| for (auto pack : it.second) { | |||
| opt.call_with_opr(pack[0], [&]() { | |||
| insert_packed_oprs(pack_id, pack, group_info[hash], replace_map, priority); | |||
| }, OprPropertyFlag::NONE); | |||
| pack_id += 1; | |||
| } | |||
| } | |||
| // replace vars | |||
| auto rewriter = opt.graph().make_rewriter(); | |||
| auto cb_replace = [&](OperatorNodeBase* opr) { | |||
| for (auto i : opr->input()) { | |||
| auto iter = replace_map.find(i); | |||
| if (iter != replace_map.end()) { | |||
| rewriter.replace_var(i, iter->second, nullptr); | |||
| } | |||
| } | |||
| rewriter.auto_replace_outputs(opr); | |||
| }; | |||
| opt.graph().iter(cb_replace); | |||
| rewriter.apply_inplace(); | |||
| } | |||
| #else | |||
| /* ======================= PackAllReduceScanPass ====================== */ | |||
| const char* PackAllReduceScanPass::name() const { | |||
| return "pack_allreduce_scan"; | |||
| } | |||
| void PackAllReduceScanPass::apply(OptState& opt) const { | |||
| } | |||
| bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||
| return true; | |||
| } | |||
| /* ======================= PackAllReduceReplacePass ====================== */ | |||
| const char* PackAllReduceReplacePass::name() const { | |||
| return "pack_allreduce_replace"; | |||
| } | |||
| void PackAllReduceReplacePass::apply(OptState& opt) const {} | |||
| uint64_t PackAllReduceReplacePass::collect_groups( | |||
| OperatorNodeBase* opr, | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||
| ThinHashMap<uint64_t, cg::OprNodeArray>& groups) { | |||
| return 0; | |||
| } | |||
| void PackAllReduceReplacePass::divide_packs( | |||
| const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||
| size_t max_size) { | |||
| } | |||
| void PackAllReduceReplacePass::insert_packed_oprs( | |||
| size_t pack_id, | |||
| const cg::OprNodeArray& pack, | |||
| std::shared_ptr<GroupInfo> info, | |||
| ThinHashMap<VarNode*, VarNode*>& replace_map, int priority) { | |||
| } | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -11,6 +11,8 @@ | |||
| #pragma once | |||
| #include <vector> | |||
| #include "megbrain/gopt/framework.h" | |||
| namespace mgb { | |||
| @@ -90,6 +92,45 @@ namespace gopt { | |||
| void apply(OptState& opt) const override; | |||
| }; | |||
| //! scan allreduces of param grads | |||
| class PackAllReduceScanPass final : public Pass { | |||
| public: | |||
| const char* name() const override; | |||
| void apply(OptState& opt) const override; | |||
| private: | |||
| // check pattern param -> grad -> allreduce | |||
| static bool check_pattern(OperatorNodeBase* opr); | |||
| }; | |||
| //! pack allreduces of param grads | |||
| class PackAllReduceReplacePass final : public Pass { | |||
| public: | |||
| class GroupInfo; | |||
| const char* name() const override; | |||
| void apply(OptState& opt) const override; | |||
| // collect allreduces and divide into groups | |||
| static uint64_t collect_groups( | |||
| OperatorNodeBase* opr, | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>>& group_info, | |||
| ThinHashMap<uint64_t, cg::OprNodeArray>& groups); | |||
| // divide groups into packs, max_size in MB | |||
| static void divide_packs( | |||
| const ThinHashMap<uint64_t, cg::OprNodeArray>& groups, | |||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>>& packs, | |||
| size_t max_size); | |||
| // insert packed operators and update replace_map | |||
| static void insert_packed_oprs( | |||
| size_t pack_id, | |||
| const cg::OprNodeArray& pack, | |||
| std::shared_ptr<GroupInfo> info, | |||
| ThinHashMap<VarNode*, VarNode*>& replace_map, int priority); | |||
| }; | |||
| } // namespace gopt | |||
| } // namespace mgb | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/gopt/basic_arith.h" | |||
| #include "megbrain/gopt/misc.h" | |||
| #include "megbrain/opr/basic_arith_wrapper.h" | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/opr/cond.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| @@ -410,4 +411,322 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { | |||
| check(x_q8_q8, x_q8_fp32_q8_); | |||
| } | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include "megbrain/opr/collective_comm.h" | |||
| #include "../../opr-mm/test/mock_client.h" | |||
| TEST_PASS(PackAllReduceScanPass, Basic) { | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().allreduce_pack_max_size = 5000; | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto cn = CompNode::load("gpux"); | |||
| auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5}); | |||
| auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6}); | |||
| auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}); | |||
| auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}); | |||
| auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0); | |||
| auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1); | |||
| auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0); | |||
| auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1); | |||
| auto grad0 = opr::VirtualGrad::make(y0, x0); | |||
| auto grad1 = opr::VirtualGrad::make(y0, x1); | |||
| auto grad2 = opr::VirtualGrad::make(y1, x0); | |||
| auto grad3 = opr::VirtualGrad::make(y1, x1); | |||
| auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
| auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), | |||
| "grad0", 2, 0, 0, client, mode)[0]; | |||
| auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), | |||
| "grad1", 2, 0, 0, client, mode)[0]; | |||
| auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), | |||
| "grad2", 2, 0, 0, client, mode)[0]; | |||
| auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), | |||
| "grad3", 2, 0, 0, client, mode)[0]; | |||
| gopt::GraphOptimizer() | |||
| .add_pass<gopt::PackAllReduceScanPass>() | |||
| .apply({{comm0, comm1, comm2, comm3}}); | |||
| auto get_hash = [] (const SymbolVar& symvar) { | |||
| cg::OperatorNodeBase* opr = symvar.node()->owner_opr(); | |||
| return opr->cast_final_safe<opr::CollectiveComm>().pack_hash(); | |||
| }; | |||
| uint64_t hash0 = get_hash(comm0); | |||
| uint64_t hash1 = get_hash(comm1); | |||
| uint64_t hash2 = get_hash(comm2); | |||
| uint64_t hash3 = get_hash(comm3); | |||
| ASSERT_EQ(hash0, hash1); | |||
| ASSERT_EQ(hash2, hash3); | |||
| ASSERT_NE(hash0, hash2); | |||
| } | |||
| TEST_PASS(PackAllReduceReplacePass, CollectGroups) { | |||
| REQUIRE_GPU(2); | |||
| auto cns = load_multiple_xpus(2); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 2; | |||
| auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0"); | |||
| auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1"); | |||
| using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
| ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
| auto add_opr = [&] (const CompNode& cn, TensorShape shape, const DType& dt, | |||
| std::shared_ptr<test::MockGroupClient> client, uint64_t extra_hash) { | |||
| auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt); | |||
| auto wrt = opr::SharedDeviceTensor::make(*graph, dev0); | |||
| auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt); | |||
| auto target = opr::SharedDeviceTensor::make(*graph, dev1); | |||
| auto grad = opr::VirtualGrad::make(target, wrt); | |||
| auto comm = opr::CollectiveComm::make( | |||
| {grad}, graph.get(), "key", 2, 0, 0, client, | |||
| opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] | |||
| .node()->owner_opr(); | |||
| comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash); | |||
| return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info, groups); | |||
| }; | |||
| uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1); | |||
| uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1); // same | |||
| uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1); // comp_node | |||
| uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1); // dtype | |||
| uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1); // client | |||
| uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2); // extra_hash | |||
| ASSERT_EQ(hash0, hash1); | |||
| std::set<uint64_t> s; | |||
| s.insert(hash0); | |||
| s.insert(hash1); | |||
| s.insert(hash2); | |||
| s.insert(hash3); | |||
| s.insert(hash4); | |||
| s.insert(hash5); | |||
| ASSERT_EQ(5, s.size()); | |||
| ASSERT_EQ(1, group_info.count(hash0)); | |||
| ASSERT_EQ(1, group_info.count(hash1)); | |||
| ASSERT_EQ(1, group_info.count(hash2)); | |||
| ASSERT_EQ(1, group_info.count(hash3)); | |||
| ASSERT_EQ(1, group_info.count(hash4)); | |||
| ASSERT_EQ(1, group_info.count(hash5)); | |||
| ASSERT_EQ(2, groups[hash0].size()); | |||
| ASSERT_EQ(2, groups[hash1].size()); | |||
| ASSERT_EQ(1, groups[hash2].size()); | |||
| ASSERT_EQ(1, groups[hash3].size()); | |||
| ASSERT_EQ(1, groups[hash4].size()); | |||
| ASSERT_EQ(1, groups[hash5].size()); | |||
| } | |||
| TEST_PASS(PackAllReduceReplacePass, DividePacks) { | |||
| auto cn = CompNode::load("gpux"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
| ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
| ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs; | |||
| auto insert_opr = [&] (size_t size) { | |||
| auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)}); | |||
| auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||
| auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||
| "key", 2, 0, 0, client, mode)[0]; | |||
| auto opr = symvar.node()->owner_opr(); | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| comm.set_pack_hash(1); | |||
| return opr; | |||
| }; | |||
| auto pack_size = [&] (cg::OprNodeArray& pack) { | |||
| size_t sum = 0; | |||
| for (size_t i = 0; i < pack.size(); i++) { | |||
| auto var = pack[i]->input(0); | |||
| sum += var->dtype().size(var->shape().total_nr_elems()); | |||
| } | |||
| return sum; | |||
| }; | |||
| groups[0].push_back(insert_opr(100)); // group0, pack0, size=1100 | |||
| groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 | |||
| groups[0].push_back(insert_opr(400)); // group0, pack0, size=1100 | |||
| groups[0].push_back(insert_opr(300)); // group0, pack0, size=1100 | |||
| groups[0].push_back(insert_opr(500)); // group0, pack1, size=800 | |||
| groups[0].push_back(insert_opr(200)); // group0, pack1, size=800 | |||
| groups[0].push_back(insert_opr(100)); // group0, pack1, size=800 | |||
| groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 | |||
| groups[1].push_back(insert_opr(400)); // group1, pack0, size=900 | |||
| groups[1].push_back(insert_opr(300)); // group1, pack0, size=900 | |||
| groups[1].push_back(insert_opr(100)); // group1, pack0, size=900 | |||
| gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000); | |||
| ASSERT_EQ(2, packs.size()); | |||
| ASSERT_EQ(2, packs[0].size()); | |||
| ASSERT_EQ(4, packs[0][0].size()); | |||
| ASSERT_EQ(1100, pack_size(packs[0][0])); | |||
| ASSERT_EQ(3, packs[0][1].size()); | |||
| ASSERT_EQ(800, pack_size(packs[0][1])); | |||
| ASSERT_EQ(1, packs[1].size()); | |||
| ASSERT_EQ(4, packs[1][0].size()); | |||
| ASSERT_EQ(900, pack_size(packs[1][0])); | |||
| } | |||
| TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||
| auto cn = CompNode::load("gpux"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
| size_t nr_devices = 2; | |||
| uint32_t rank = 0; | |||
| uint32_t root = 0; | |||
| using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
| ThinHashMap<uint64_t, cg::OprNodeArray> groups; | |||
| auto insert_opr = [&] (const TensorShape& shape) { | |||
| auto dev = std::make_shared<DeviceTensorND>(cn, shape); | |||
| auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||
| auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||
| "key", nr_devices, rank, root, client, mode)[0]; | |||
| auto opr = symvar.node()->owner_opr(); | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| comm.set_pack_hash(1); | |||
| gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups); | |||
| return symvar; | |||
| }; | |||
| auto shape_x = TensorShape{100, 200}; | |||
| auto shape_y = TensorShape{200, 400}; | |||
| auto x = insert_opr(shape_x); | |||
| auto y = insert_opr(shape_y); | |||
| ASSERT_EQ(1, group_info.size()); | |||
| ASSERT_EQ(1, groups.size()); | |||
| auto info = group_info.begin()->second; | |||
| auto pack = groups.begin()->second; | |||
| size_t pack_id = 0; | |||
| ThinHashMap<VarNode*, VarNode*> replace_map; | |||
| gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info, replace_map, -1); | |||
| auto grad_x = SymbolVar(x.node()->owner_opr()->input(0)); | |||
| auto grad_y = SymbolVar(y.node()->owner_opr()->input(0)); | |||
| auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); | |||
| std::string key = ssprintf("grad_pack_%zu", pack_id); | |||
| auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), | |||
| key, nr_devices, rank, root, client, mode)[0]; | |||
| std::vector<size_t> partition; | |||
| partition.push_back(shape_x.total_nr_elems()); | |||
| partition.push_back(shape_y.total_nr_elems()); | |||
| auto splits = opr::Split::make(allreduce, | |||
| opr::Split::Options::make_partition(allreduce, 0, partition)); | |||
| ASSERT_EQ(2, splits.size()); | |||
| auto dest_x = splits[0].reshape(shape_x); | |||
| auto dest_y = splits[1].reshape(shape_y); | |||
| ASSERT_EQ(2, replace_map.size()); | |||
| ASSERT_TRUE(replace_map.count(x.node()) > 0); | |||
| ASSERT_EQ(replace_map.at(x.node()), dest_x.node()); | |||
| ASSERT_TRUE(replace_map.count(y.node()) > 0); | |||
| ASSERT_EQ(replace_map.at(y.node()), dest_y.node()); | |||
| } | |||
| TEST_PASS(PackAllReduceReplacePass, Equivalence) { | |||
| REQUIRE_GPU(2); | |||
| auto cns = load_multiple_xpus(2); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto build_graph = [&] (uint32_t rank, std::shared_ptr<ComputingGraph> graph, | |||
| SymbolVarArray& array) { | |||
| HostTensorGenerator<> gen; | |||
| auto cn = cns[rank]; | |||
| auto host_x = gen({1, 1000}); | |||
| auto host_y = gen({1000, 1}); | |||
| auto dev_x = std::make_shared<DeviceTensorND>(cn); | |||
| auto dev_y = std::make_shared<DeviceTensorND>(cn); | |||
| dev_x->copy_from(*host_x).sync(); | |||
| dev_y->copy_from(*host_y).sync(); | |||
| auto x = opr::SharedDeviceTensor::make(*graph, dev_x); | |||
| auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y); | |||
| auto loss = opr::MatrixMul::make(x, y).flatten(); | |||
| auto grad_x = opr::VirtualGrad::make(loss, x); | |||
| auto grad_y = opr::VirtualGrad::make(loss, y); | |||
| using Mode = opr::CollectiveComm::Param::Mode; | |||
| bool is_root = (rank == 0); | |||
| auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), | |||
| "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||
| auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), | |||
| "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||
| graph->options().allreduce_pack_max_size = 5000; | |||
| graph->options().allreduce_pack_ignore_first = 0; | |||
| auto dest_vars = gopt::GraphOptimizer{} | |||
| .add_pass<gopt::PackAllReduceScanPass>() | |||
| .add_pass<gopt::PackAllReduceReplacePass>() | |||
| .apply({{reduced_x, reduced_y}}).endpoint_vars(); | |||
| array.emplace_back(reduced_x); | |||
| array.emplace_back(reduced_y); | |||
| array.emplace_back(dest_vars[0]); | |||
| array.emplace_back(dest_vars[1]); | |||
| }; | |||
| auto run = [&] (uint32_t rank) { | |||
| auto graph = ComputingGraph::make(); | |||
| SymbolVarArray array; | |||
| build_graph(rank, graph, array); | |||
| HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1; | |||
| graph->options().allreduce_pack_max_size = 0; | |||
| auto func = graph->compile({make_callback_copy(array[0], host_reduced_x), | |||
| make_callback_copy(array[1], host_reduced_y), | |||
| make_callback_copy(array[2], host_dest_0), | |||
| make_callback_copy(array[3], host_dest_1)}); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0); | |||
| MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1); | |||
| }; | |||
| std::thread t0(run, 0); | |||
| std::thread t1(run, 1); | |||
| t0.join(); | |||
| t1.join(); | |||
| } | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -461,16 +461,7 @@ void CollectiveComm::opr_register() { | |||
| m_rank = reg_info.rank; | |||
| m_root = reg_info.root_rank; | |||
| MegRayCommunicatorBuilder* builder; | |||
| { | |||
| static std::mutex user_data_mtx; | |||
| std::unique_lock<std::mutex> lk(user_data_mtx); | |||
| builder = owner_graph()->options().user_data | |||
| .get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
| } | |||
| m_megray_comm = builder->get_megray_comm( | |||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
| reg_info.hash, m_key, m_nr_devices, m_rank, | |||
| get_megray_backend(m_backend), m_group_client); | |||
| @@ -736,13 +727,15 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | |||
| return opr::CollectiveComm::make( | |||
| auto new_opr = CollectiveComm::make( | |||
| to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||
| opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||
| opr.group_client(), opr.dev_buffers(), opr.param(), | |||
| opr.dtype(), opr.backend(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | |||
| return new_opr; | |||
| } | |||
| MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm); | |||
| @@ -54,13 +54,7 @@ void RemoteSend::scn_do_execute() { | |||
| auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | |||
| comp_node.get_uid()); | |||
| auto megray_comm_builder = | |||
| owner_graph() | |||
| ->options() | |||
| .user_data | |||
| .get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
| m_megray_comm = megray_comm_builder->get_megray_comm( | |||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
| reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
| m_init = true; | |||
| } | |||
| @@ -158,13 +152,7 @@ void RemoteRecv::scn_do_execute() { | |||
| m_peer.key, 2, false, 1, | |||
| comp_node.get_uid()); | |||
| auto megray_comm_builder = | |||
| owner_graph() | |||
| ->options() | |||
| .user_data | |||
| .get_user_data_or_create<MegRayCommunicatorBuilder>(); | |||
| m_megray_comm = megray_comm_builder->get_megray_comm( | |||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
| reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
| m_init = true; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||
| std::unique_lock<std::mutex> lk(m_mtx); | |||
| bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { | |||
| std::unique_lock<std::mutex> lk(m_map_mtx); | |||
| auto it = m_megray_comms.find(hash); | |||
| if (it != m_megray_comms.end()) { | |||
| comm = it->second; | |||
| @@ -24,27 +24,37 @@ bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Comm | |||
| return false; | |||
| } | |||
| void MegRayCommunicatorBuilder::emplace(uint64_t hash, | |||
| void MegRayCommBuilder::emplace(uint64_t hash, | |||
| std::shared_ptr<MegRay::Communicator> comm) { | |||
| std::unique_lock<std::mutex> lk(m_mtx); | |||
| std::unique_lock<std::mutex> lk(m_map_mtx); | |||
| m_megray_comms.emplace(hash, comm); | |||
| } | |||
| std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm( | |||
| std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
| uint64_t hash, std::string key, uint32_t size, uint32_t rank, | |||
| MegRay::Backend backend, | |||
| std::shared_ptr<mgb::opr::GroupClient> group_client) { | |||
| { | |||
| // singleton pattern | |||
| std::unique_lock<std::mutex> lk(sm_instance_mtx); | |||
| if (sm_instance == nullptr) { | |||
| sm_instance = new MegRayCommBuilder(); | |||
| } | |||
| } | |||
| std::shared_ptr<MegRay::Communicator> comm; | |||
| if (!find(hash, comm)) { | |||
| if (!sm_instance->find(hash, comm)) { | |||
| comm = MegRay::get_communicator(size, rank, backend); | |||
| auto uid = comm->get_uid(); | |||
| auto uids = group_client->gather_uid(uid, key, size, rank); | |||
| mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | |||
| emplace(hash, comm); | |||
| sm_instance->emplace(hash, comm); | |||
| } | |||
| return comm; | |||
| } | |||
| MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); | |||
| MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | |||
| std::mutex MegRayCommBuilder::sm_instance_mtx; | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -81,6 +81,10 @@ public: | |||
| return m_group_client; | |||
| } | |||
| void set_pack_hash(uint64_t hash) { m_pack_hash = hash; } | |||
| uint64_t pack_hash() const { return m_pack_hash; } | |||
| std::shared_ptr<MegRay::Context> megray_ctx() const { | |||
| return m_megray_ctx; | |||
| } | |||
| @@ -123,6 +127,9 @@ private: | |||
| // whose shape infer should be disabled *during* static infer phase. | |||
| bool m_enable_shape_infer = false; | |||
| //! set in PackAllReduceScanPass and used in PackAllReduceReplacePass | |||
| uint64_t m_pack_hash = 0; | |||
| std::shared_ptr<MegRay::Context> m_megray_ctx; | |||
| std::shared_ptr<MegRay::Communicator> m_megray_comm; | |||
| bool m_init = false; | |||
| @@ -126,6 +126,8 @@ class GroupClient { | |||
| virtual ~GroupClient() = default; | |||
| public: | |||
| virtual const std::string& get_addr() const = 0; | |||
| virtual GroupManager::RegisterInfo opr_register(const std::string& key, | |||
| size_t nr_devices, | |||
| bool is_root, int rank, | |||
| @@ -23,18 +23,19 @@ namespace opr { | |||
| /*! | |||
| * gather MegRay unique ids and build communicator, use hash for deduplication | |||
| */ | |||
| class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData { | |||
| MGB_TYPEINFO_OBJ_DECL; | |||
| class MegRayCommBuilder { | |||
| private: | |||
| bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | |||
| void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||
| std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | |||
| std::mutex m_mtx; | |||
| std::mutex m_map_mtx; | |||
| static MegRayCommBuilder* sm_instance; | |||
| static std::mutex sm_instance_mtx; | |||
| public: | |||
| std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||
| static std::shared_ptr<MegRay::Communicator> get_megray_comm( | |||
| uint64_t hash, std::string key, uint32_t size, uint32_t rank, | |||
| MegRay::Backend backend, | |||
| std::shared_ptr<mgb::opr::GroupClient> group_client); | |||
| @@ -47,7 +47,7 @@ public: | |||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override; | |||
| const std::string& get_addr() const { | |||
| const std::string& get_addr() const override { | |||
| return m_addr; | |||
| } | |||
| @@ -17,11 +17,10 @@ | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/test/helper.h" | |||
| #include "megbrain/graph.h" | |||
| #include "mock_client.h" | |||
| using namespace mgb; | |||
| namespace { | |||
| using Mode = opr::CollectiveComm::Param::Mode; | |||
| SymbolVar make_all_reduce_output(const Mode mode, | |||
| @@ -41,41 +40,6 @@ SymbolVarArray make_reduce_scatter_sum_output(const SymbolVarArray& inputs) { | |||
| rdc, opr::Split::Options::make_average(0, inputs.size())); | |||
| } | |||
| class MockGroupClient final : public opr::GroupClient { | |||
| public: | |||
| ~MockGroupClient() override = default; | |||
| opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||
| size_t nr_devices, | |||
| bool is_root, int rank, | |||
| uintptr_t stream) { | |||
| return m_mgr.opr_register(key, nr_devices, is_root, rank, stream); | |||
| } | |||
| std::vector<std::string> gather_uid(const std::string& uid, | |||
| const std::string& key, uint32_t size, uint32_t rank) { | |||
| return m_mgr.gather_uid(uid, key, size, rank); | |||
| } | |||
| void set_output_shape(const std::string& key, | |||
| const TensorShape& shape) override { | |||
| m_mgr.set_output_shape(key, shape); | |||
| } | |||
| TensorShape get_output_shape(const std::string& key) override { | |||
| return m_mgr.get_output_shape(key); | |||
| } | |||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||
| return m_mgr.group_barrier(size, rank); | |||
| } | |||
| private: | |||
| opr::GroupManager m_mgr; | |||
| }; | |||
| } // namespace | |||
| TEST(TestOprCollectiveComm, AllReduce) { | |||
| REQUIRE_GPU(2); | |||
| @@ -88,7 +52,7 @@ TEST(TestOprCollectiveComm, AllReduce) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
| @@ -126,7 +90,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -187,7 +151,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -268,7 +232,7 @@ TEST(TestOprCollectiveComm, AllGather) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
| @@ -300,7 +264,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -356,7 +320,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { | |||
| HostTensorND host_out_grad0, host_out_grad1; | |||
| HostTensorND host_out_grad0_expect, host_out_grad1_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -438,7 +402,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
| @@ -471,7 +435,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { | |||
| auto host_x1 = gen({8}); | |||
| HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -528,7 +492,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { | |||
| HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; | |||
| HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -610,7 +574,7 @@ TEST(TestOprCollectiveComm, ReduceSum) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
| @@ -641,7 +605,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { | |||
| auto host_x1 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -694,7 +658,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { | |||
| HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -764,7 +728,7 @@ TEST(TestOprCollectiveComm, Broadcast) { | |||
| auto host_x0 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1, host_y_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); | |||
| @@ -794,7 +758,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { | |||
| auto host_x0 = gen({28, 28}); | |||
| HostTensorND host_y0, host_y1; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -840,7 +804,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { | |||
| HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto run_0 = [&]() { // rank 0 | |||
| auto graph0 = ComputingGraph::make(); | |||
| @@ -14,51 +14,14 @@ | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/system.h" | |||
| #include "megbrain/test/helper.h" | |||
| #include "mock_client.h" | |||
| #include <thread> | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| namespace { | |||
| class MockGroupClient final : public opr::GroupClient { | |||
| public: | |||
| ~MockGroupClient() override = default; | |||
| opr::GroupManager::RegisterInfo opr_register(const std::string& key, | |||
| size_t nr_devices, | |||
| bool is_root, int rank, | |||
| uint64_t comp_node_hash) { | |||
| return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||
| } | |||
| std::vector<std::string> gather_uid(const std::string& uid, | |||
| const std::string& key, uint32_t size, uint32_t rank) { | |||
| return m_mgr.gather_uid(uid, key, size, rank); | |||
| } | |||
| void set_output_shape(const std::string& key, | |||
| const TensorShape& shape) override { | |||
| m_mgr.set_output_shape(key, shape); | |||
| } | |||
| TensorShape get_output_shape(const std::string& key) override { | |||
| return m_mgr.get_output_shape(key); | |||
| } | |||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||
| return m_mgr.group_barrier(size, rank); | |||
| } | |||
| private: | |||
| opr::GroupManager m_mgr; | |||
| }; | |||
| const auto send_tag = RemoteIOBase::Type::SEND; | |||
| const auto recv_tag = RemoteIOBase::Type::RECV; | |||
| } // anonymous namespace | |||
| const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||
| const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||
| TEST(TestOprIORemote, Identity) { | |||
| REQUIRE_GPU(2); | |||
| @@ -69,7 +32,7 @@ TEST(TestOprIORemote, Identity) { | |||
| auto host_x = gen({28, 28}); | |||
| HostTensorND host_y; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | |||
| @@ -90,7 +53,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||
| HostTensorGenerator<> gen; | |||
| auto host_x = gen({2, 3}, cns[1]); | |||
| HostTensorND host_x_get; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto sender = [&]() { | |||
| auto graph = ComputingGraph::make(); | |||
| @@ -123,7 +86,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||
| HostTensorGenerator<> gen; | |||
| auto host_x = gen({2, 3}, cns[0]); | |||
| HostTensorND host_x_get; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto sender = [&]() { | |||
| sys::set_thread_name("sender"); | |||
| @@ -157,7 +120,7 @@ TEST(TestOprIORemote, APlusB) { | |||
| HostTensorGenerator<> gen; | |||
| auto host_x = gen({5, 7}, cns[0]), host_y = gen({5, 1}, cns[0]); | |||
| HostTensorND host_z; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto sender = [&]() { | |||
| auto graph = ComputingGraph::make(); | |||
| @@ -208,7 +171,7 @@ TEST(TestOprIORemote, SendGrad) { | |||
| HostTensorGenerator<> gen; | |||
| auto host_x = gen({2, 3}, cns[0]); | |||
| HostTensorND host_gx, host_loss; | |||
| auto client = std::make_shared<MockGroupClient>(); | |||
| auto client = std::make_shared<test::MockGroupClient>(); | |||
| auto sender = [&]() { | |||
| sys::set_thread_name("sender"); | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * \file src/opr-mm/test/mock_client.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/opr/group_manager.h" | |||
| namespace mgb { | |||
| namespace test { | |||
| class MockGroupClient final : public opr::GroupClient { | |||
| public: | |||
| using RegisterInfo = opr::GroupManager::RegisterInfo; | |||
| MockGroupClient(const std::string& server_addr = "mock_addr") : | |||
| m_addr(server_addr) { | |||
| } | |||
| ~MockGroupClient() override = default; | |||
| const std::string& get_addr() const { | |||
| return m_addr; | |||
| } | |||
| RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||
| bool is_root, int rank, uint64_t comp_node_hash) { | |||
| return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||
| } | |||
| std::vector<std::string> gather_uid(const std::string& uid, | |||
| const std::string& key, uint32_t size, uint32_t rank) { | |||
| return m_mgr.gather_uid(uid, key, size, rank); | |||
| } | |||
| void set_output_shape(const std::string& key, | |||
| const TensorShape& shape) override { | |||
| m_mgr.set_output_shape(key, shape); | |||
| } | |||
| TensorShape get_output_shape(const std::string& key) override { | |||
| return m_mgr.get_output_shape(key); | |||
| } | |||
| uint32_t group_barrier(uint32_t size, uint32_t rank) override { | |||
| return m_mgr.group_barrier(size, rank); | |||
| } | |||
| private: | |||
| const std::string m_addr; | |||
| opr::GroupManager m_mgr; | |||
| }; | |||
| } // namespace test | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||