|
|
|
@@ -29,7 +29,7 @@ |
|
|
|
#include "parallel/context.h" |
|
|
|
#include "parallel/device_manager.h" |
|
|
|
#include "parallel/costmodel_context.h" |
|
|
|
#ifdef ENABLE_GPUQUE |
|
|
|
#ifdef ENABLE_GPU_COLLECTIVE |
|
|
|
#include "device/gpu/distribution/collective_init.h" |
|
|
|
#else |
|
|
|
#include "device/gpu/distribution/collective_fake_init.h" |
|
|
|
@@ -300,7 +300,7 @@ PYBIND11_MODULE(_c_expression, m) { |
|
|
|
(void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib") |
|
|
|
.def(py::init()) |
|
|
|
.def("reg_op", &OpLib::RegOp, "Register op info."); |
|
|
|
#ifdef ENABLE_GPUQUE |
|
|
|
#ifdef ENABLE_GPU_COLLECTIVE |
|
|
|
(void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective, |
|
|
|
"Init gpu collective communication mode."); |
|
|
|
(void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective, |
|
|
|
|