Browse Source

delete useless code for allreduce

tags/v0.6.0-beta
yangzhenzhang 5 years ago
parent
commit
e6cef98e95
2 changed files with 0 additions and 8 deletions
  1. +0
    -5
      mindspore/ops/operations/comm_ops.py
  2. +0
    -3
      tests/ut/cpp/parallel/step_parallel_test.cc

+ 0
- 5
mindspore/ops/operations/comm_ops.py View File

@@ -100,11 +100,6 @@ class AllReduce(PrimitiveWithInfer):
self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)

def vm_impl(self, x):
"""Implement by vm mode."""
x = x.asnumpy()
return Tensor(x)

def infer_shape(self, x_shape):
return x_shape



+ 0
- 3
tests/ut/cpp/parallel/step_parallel_test.cc View File

@@ -294,9 +294,6 @@ TEST_F(TestStepParallel, CreatOpInstance) {
ASSERT_TRUE(allreduce_ptr);
if (nullptr != allreduce_ptr) {
MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
if (!allreduce_ptr->HasComputeFunction()) {
MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented";
}

std::vector<py::object> arglist;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),


Loading…
Cancel
Save