| @@ -31,6 +31,7 @@ | |||||
| #endif | #endif | ||||
| #if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #endif | #endif | ||||
| @@ -18,6 +18,7 @@ | |||||
| #endif | #endif | ||||
| #if MEGDNN_WITH_ROCM | #if MEGDNN_WITH_ROCM | ||||
| #include "src/rocm/megcore/computing_context.hpp" | #include "src/rocm/megcore/computing_context.hpp" | ||||
| #endif | #endif | ||||
| @@ -182,7 +182,8 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
| } | } | ||||
| dev_type = DeviceType::MULTITHREAD; | dev_type = DeviceType::MULTITHREAD; | ||||
| ptr += 11; | ptr += 11; | ||||
| } else { | |||||
| } | |||||
| else { | |||||
| if (ptr[1] != 'p' || ptr[2] != 'u') { | if (ptr[1] != 'p' || ptr[2] != 'u') { | ||||
| err(); | err(); | ||||
| } | } | ||||
| @@ -237,7 +238,7 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
| //! num_steam store the nr_thread | //! num_steam store the nr_thread | ||||
| std::swap(num_dev, num_stream); | std::swap(num_dev, num_stream); | ||||
| } | } | ||||
| return {dev_type, num_dev, {num_stream}}; | return {dev_type, num_dev, {num_stream}}; | ||||
| } | } | ||||
| @@ -1021,13 +1021,12 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( | |||||
| { | { | ||||
| auto type = cn_impl->env().property().type; | auto type = cn_impl->env().property().type; | ||||
| mgb_throw_if( | |||||
| type != CompNode::DeviceType::CPU && | |||||
| type != CompNode::DeviceType::CUDA | |||||
| && type != CompNode::DeviceType::ATLAS && | |||||
| type != CompNode::DeviceType::CAMBRICON, | |||||
| MegBrainError, | |||||
| "currently CPU can only wait for CPU, CUDA, ATLAS, CAMBRICON" | |||||
| mgb_throw_if(type != CompNode::DeviceType::CPU | |||||
| && type != CompNode::DeviceType::CUDA | |||||
| && type != CompNode::DeviceType::ATLAS | |||||
| , | |||||
| MegBrainError, | |||||
| "currently CPU can only wait for CPU, CUDA, ATLAS" | |||||
| ); | ); | ||||
| } | } | ||||
| @@ -36,6 +36,7 @@ | |||||
| #endif | #endif | ||||
| using namespace mgb; | using namespace mgb; | ||||
| /* =================== MegDNNHandle =================== */ | /* =================== MegDNNHandle =================== */ | ||||
| @@ -232,6 +233,7 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, | |||||
| } | } | ||||
| #endif | #endif | ||||
| #if MGB_ATLAS | #if MGB_ATLAS | ||||
| void mgb::_on_atlas_error(const char* expr, int err, const char* file, | void mgb::_on_atlas_error(const char* expr, int err, const char* file, | ||||
| @@ -421,6 +423,7 @@ void CompNodeEnv::fini() { | |||||
| MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); | MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); | ||||
| } | } | ||||
| #endif | #endif | ||||
| #if MGB_ROCM | #if MGB_ROCM | ||||
| if (m_property.type == DeviceType::ROCM) { | if (m_property.type == DeviceType::ROCM) { | ||||
| m_rocm_env.activate(); | m_rocm_env.activate(); | ||||
| @@ -440,6 +443,7 @@ void CompNodeEnv::fini() { | |||||
| MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); | MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); | ||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| #if MGB_ENABLE_COMP_NODE_ASYNC_INIT | #if MGB_ENABLE_COMP_NODE_ASYNC_INIT | ||||
| @@ -73,6 +73,7 @@ std::string CudaError::get_cuda_extra_info() { | |||||
| #endif | #endif | ||||
| } | } | ||||
| AtlasError::AtlasError(const std::string &msg): | AtlasError::AtlasError(const std::string &msg): | ||||
| SystemError(msg) | SystemError(msg) | ||||
| { | { | ||||
| @@ -82,7 +82,7 @@ class CompNode { | |||||
| CAMBRICON = 3, | CAMBRICON = 3, | ||||
| ROCM = 8, | ROCM = 8, | ||||
| ATLAS = 9, | ATLAS = 9, | ||||
| MULTITHREAD, | |||||
| MULTITHREAD = 11, | |||||
| MAX_DEVICE_ID, | MAX_DEVICE_ID, | ||||
| }; | }; | ||||
| static constexpr size_t NR_DEVICE_TYPE = | static constexpr size_t NR_DEVICE_TYPE = | ||||
| @@ -63,6 +63,7 @@ | |||||
| #endif //MGB_ENABLE_LOGGING | #endif //MGB_ENABLE_LOGGING | ||||
| #endif //MGB_CUDA | #endif //MGB_CUDA | ||||
| #if MGB_ATLAS | #if MGB_ATLAS | ||||
| #include "megcore_atlas.h" | #include "megcore_atlas.h" | ||||
| #include <atomic> | #include <atomic> | ||||
| @@ -205,6 +206,7 @@ namespace mgb { | |||||
| #endif | #endif | ||||
| #if MGB_ROCM | #if MGB_ROCM | ||||
| [[noreturn]] void _on_hip_error(const char* expr, hipError_t err, | [[noreturn]] void _on_hip_error(const char* expr, hipError_t err, | ||||
| const char* file, const char* func, int line); | const char* file, const char* func, int line); | ||||
| @@ -369,6 +371,7 @@ public: | |||||
| const ContinuationCtx<cudaStream_t>& cont); | const ContinuationCtx<cudaStream_t>& cont); | ||||
| #endif | #endif | ||||
| #if MGB_ATLAS | #if MGB_ATLAS | ||||
| struct AtlasEnv { | struct AtlasEnv { | ||||
| int device = -1; | int device = -1; | ||||
| @@ -139,6 +139,11 @@ public: | |||||
| CudaError(const std::string& msg); | CudaError(const std::string& msg); | ||||
| }; | }; | ||||
| class EnFlameError final : public SystemError { | |||||
| public: | |||||
| EnFlameError(const std::string& msg); | |||||
| }; | |||||
| class AtlasError final: public SystemError { | class AtlasError final: public SystemError { | ||||
| public: | public: | ||||
| AtlasError(const std::string& msg); | AtlasError(const std::string& msg); | ||||
| @@ -166,6 +166,7 @@ TEST(TestCompNode, Load) { | |||||
| ASSERT_NE(atlas0, atlas1); | ASSERT_NE(atlas0, atlas1); | ||||
| #endif | #endif | ||||
| } | } | ||||
| TEST(TestCompNode, FreeAfterFinalize) { | TEST(TestCompNode, FreeAfterFinalize) { | ||||
| @@ -754,6 +755,7 @@ TEST(TestCompNodeCambricon, P2PCopy) { | |||||
| #endif | #endif | ||||
| #endif // MGB_CAMBRICON | #endif // MGB_CAMBRICON | ||||
| #if MGB_ATLAS | #if MGB_ATLAS | ||||
| TEST(TestCompNodeAtlas, D2DCopy) { | TEST(TestCompNodeAtlas, D2DCopy) { | ||||