You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

collective_comm.cpp 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. /**
  2. * \file imperative/src/test/collective_comm.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./helper.h"
  12. #include "megbrain/imperative/ops/autogen.h"
  13. #include "megbrain/opr/mm_handler.h"
  14. using namespace mgb;
  15. using namespace imperative;
  16. TEST(TestImperative, AllReduceBasic) {
  17. REQUIRE_GPU(2);
  18. const char* server_addr = "127.0.0.1";
  19. uint32_t port = 3456;
  20. mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0);
  21. HostTensorGenerator<> gen;
  22. CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1");
  23. auto host_x = gen({233}, cn0), host_y = gen({233}, cn1);
  24. auto expect = gen({233});
  25. for (size_t i = 0; i < 233; ++i) {
  26. expect->ptr<float>()[i] = host_x->ptr<float>()[i] + host_y->ptr<float>()[i];
  27. }
  28. auto run = [&](std::shared_ptr<HostTensorND> hnd, uint32_t idx) {
  29. auto def = imperative::CollectiveComm::make(
  30. megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, "all_reduce", 2,
  31. idx, idx == 0, false, server_addr, port, dtype::Float32(), "nccl", "");
  32. auto inp = Tensor::make(*hnd);
  33. SmallVector<LogicalTensorDesc> output_descs;
  34. auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
  35. HostTensorND host_v;
  36. host_v.copy_from(oup[0]->dev_tensor()).sync();
  37. MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
  38. };
  39. std::thread t0(std::bind(run, host_x, 0));
  40. std::thread t1(std::bind(run, host_y, 1));
  41. t0.join();
  42. t1.join();
  43. }
  44. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}