|
- #pragma once
-
- #include <memory>
- #include <mutex>
-
- #include "megbrain/comp_node.h"
- #include "megbrain/opr/group_manager.h"
- #include "megray.h"
-
- namespace mgb {
- namespace opr {
-
- /*!
-  * \brief look up the MegRay::DType that corresponds to a megdnn::DType
-  *
-  * Only the declaration is visible in this header. How dtypes without a MegRay
-  * counterpart are handled is decided by the implementation (TODO confirm).
-  */
- MegRay::DType get_megray_dtype(megdnn::DType dtype);
-
- /*!
-  * \brief resolve a backend name to a MegRay::Backend value
-  * \param backend backend name; the accepted spellings are defined by the
-  *      implementation and cannot be listed from this header
-  */
- MegRay::Backend get_megray_backend(const std::string& backend);
-
- /*!
-  * \brief obtain a shared MegRay::Context for \p comp_node
-  *
-  * NOTE(review): whether the context is created on every call or cached per
-  * comp node is not visible here -- confirm against the implementation.
-  */
- std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node);
-
- /*!
-  * \brief gathers MegRay unique ids and builds communicators from them
-  *
-  * Communicators are associated with a caller-supplied 64-bit hash, which is
-  * used for deduplication.
-  */
- class MegRayCommBuilder {
- public:
-     /*!
-      * \brief get the MegRay communicator identified by \p hash
-      *
-      * The parameter descriptions below are inferred from names and types;
-      * the implementation is not part of this header (TODO confirm).
-      *
-      * \param hash deduplication key for the communicator
-      * \param key presumably the group key passed on to \p group_client
-      * \param size presumably the number of ranks in the group
-      * \param rank presumably the rank of the caller within the group
-      * \param backend MegRay backend of the communicator
-      * \param group_client client presumably used to exchange unique ids
-      */
-     static std::shared_ptr<MegRay::Communicator> get_megray_comm(
-             uint64_t hash, std::string key, uint32_t size, uint32_t rank,
-             MegRay::Backend backend,
-             std::shared_ptr<mgb::opr::GroupClient> group_client);
-
- private:
-     //! communicators indexed by hash
-     std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
-     //! presumably guards m_megray_comms -- confirm against the implementation
-     std::mutex m_map_mtx;
-
-     //! class-wide instance pointer; creation and lifetime are handled in the
-     //! implementation file
-     static MegRayCommBuilder* sm_instance;
-     //! presumably guards sm_instance -- confirm against the implementation
-     static std::mutex sm_instance_mtx;
-
-     //! look up \p hash; \p comm is an output parameter that presumably
-     //! receives the stored communicator when the return value is true
-     bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
-     //! register \p comm under \p hash
-     void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
-     //! drop the entry for \p hash; the role of \p comm is not visible here
-     void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
- };
-
- } // namespace opr
- } // namespace mgb
-
- // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
|