
!20960 [AutoParallel]Add replace graph for conv2d

Merge pull request !20960 from lichen/add_replace_graph_for_conv2d
tags/v1.4.0
i-robot (Gitee) committed 4 years ago · commit 9f296c58d6
11 changed files with 165 additions and 41 deletions

  1. mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc (+8 -0)
  2. mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h (+1 -0)
  3. mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc (+105 -3)
  4. mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h (+5 -1)
  5. mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h (+8 -0)
  6. mindspore/core/base/core_ops.h (+1 -1)
  7. mindspore/core/ops/neighborexchange.cc (+5 -5)
  8. mindspore/core/ops/neighborexchange.h (+11 -11)
  9. mindspore/ops/_grad_experimental/grad_comm_ops.py (+7 -7)
  10. mindspore/ops/operations/_inner_ops.py (+7 -6)
  11. tests/ut/python/parallel/test_neighborexchange.py (+7 -7)

mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc (+8 -0)

@@ -100,6 +100,14 @@ AnfNodePtr CreatInt64Imm(int64_t value) {
   return ValuePtrToAnfNodePtr(value_ptr);
 }
 
+AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple) {
+  std::vector<ValuePtr> value_list;
+  std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
+                 [](const int64_t value) { return MakeValue(value); });
+  ValueTuplePtr value_tuple_ptr = std::make_shared<ValueTuple>(value_list);
+  return ValuePtrToAnfNodePtr(value_tuple_ptr);
+}
+
 std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   if (!prim) {


mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h (+1 -0)

@@ -41,6 +41,7 @@ AnfNodePtr CreatTypeInt(int64_t value);
 AnfNodePtr CreatInt64Imm(int64_t value);
 AnfNodePtr CreateInt32Tensor(int64_t value);
 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
+AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple);
 std::string HashInstanceName(const std::string &name);
 
 class GenerateGraph {


mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc (+105 -3)

@@ -25,6 +25,7 @@
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "pipeline/jit/resource.h"
 
 namespace mindspore {
@@ -230,7 +231,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
 
   if (weight_strategy[0] > 1) {
     out_channel_shard_ = true;
-    new_out_channel_ = out_channel_ / weight_strategy[1];
+    new_out_channel_ = out_channel_ / weight_strategy[0];
   } else {
     out_channel_shard_ = false;
   }
@@ -514,7 +515,7 @@ void Conv2DInfo::InferSendRecvFlag() {
   MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
 }
 
-void Conv2DInfo::InferRecvShapes() {
+void Conv2DInfo::InferOverlapShapes() {
   if (left_need_recv_) {
     Shape left_recv_shape = input_slice_shape_;
     left_recv_shape[3] = overlap_left_size_;
@@ -535,6 +536,9 @@ void Conv2DInfo::InferStridedSliceAttrs() {
     left_strided_slice_end_ = input_slice_shape_;
     left_strided_slice_end_[3] = left_rank_overlap_right_size_;
     left_strided_slice_strides_ = {1, 1, 1, 1};
+    Shape left_send_shape = input_slice_shape_;
+    left_send_shape[3] = left_rank_overlap_right_size_;
+    send_shapes_.push_back(left_send_shape);
     MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
                 << left_strided_slice_end_;
   }
@@ -544,6 +548,9 @@
     right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
     right_strided_slice_end_ = input_slice_shape_;
     right_strided_slice_strides_ = {1, 1, 1, 1};
+    Shape right_send_shape = input_slice_shape_;
+    right_send_shape[3] = right_rank_overlap_left_size_;
+    send_shapes_.push_back(right_send_shape);
     MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
                 << right_strided_slice_end_;
   }
@@ -554,11 +561,101 @@ void Conv2DInfo::InferNewOperatorAttrs() {
 
   InferSendRecvFlag();
 
-  InferRecvShapes();
+  InferOverlapShapes();
 
   InferStridedSliceAttrs();
 }
 
+OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
+  auto type = cnode->Type();
+  MS_EXCEPTION_IF_NULL(type);
+  auto tensor_type = type->cast<mindspore::TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(tensor_type);
+  auto dtype = tensor_type->element();
+  MS_EXCEPTION_IF_NULL(dtype);
+  Attr send_ranks = {SEND_RNAK_IDS, MakeValue(send_rank_ids_)};
+  Attr recv_ranks = {RECV_RNAK_IDS, MakeValue(recv_rank_ids_)};
+  Attr send_shapes = {SEND_SHAPES, MakeValue(send_shapes_)};
+  Attr recv_shapes = {RECV_SHAPES, MakeValue(recv_shapes_)};
+  Attr recv_type = {RECV_TYPE, dtype};
+  OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
+  return attrs;
+}
+
+OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
+  Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
+  Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
+  Attr mode = {MODE, MakeValue(mode_)};
+  Attr pad_mode = {PAD_MODE, MakeValue("pad")};
+  Attr pad = {PAD, MakeValue(new_pad_list_)};
+  Attr stride = {STRIDE, MakeValue(stride_)};
+  Attr dilation = {DILATION, MakeValue(dilation_)};
+  Attr group = {GROUP, MakeValue(group_)};
+  Attr data_format = {DATA_FORMAT, MakeValue(format_)};
+  OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
+  return attrs;
+}
+
+Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  auto graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  GenerateGraph gen_g = GenerateGraph(attrs_);
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateGraph Init failed";
+    return FAILED;
+  }
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
+  std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  if (left_need_send_) {
+    auto slice_left_begin = CreatTuple(left_strided_slice_begin_);
+    auto slice_left_end = CreatTuple(left_strided_slice_end_);
+    auto slice_left_strided = CreatTuple(left_strided_slice_strides_);
+    auto slice_left = gen_g.PushBack(
+      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided});
+    make_tuple_a_inputs.push_back(slice_left);
+  }
+  if (right_need_send_) {
+    auto slice_right_begin = CreatTuple(right_strided_slice_begin_);
+    auto slice_right_end = CreatTuple(right_strided_slice_end_);
+    auto slice_right_strided = CreatTuple(right_strided_slice_strides_);
+    auto slice_right = gen_g.PushBack(
+      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided});
+    make_tuple_a_inputs.push_back(slice_right);
+  }
+  auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
+  auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode);
+  auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  if (left_need_recv_) {
+    std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                      CreatInt64Imm(0)};
+    auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
+    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1),
+                                                   tuple_getitem_l};
+    auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
+    auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l});
+    make_tuple_inputs.push_back(concat_l);
+  }
+  if (right_need_recv_) {
+    std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                      CreatInt64Imm(0)};
+    auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
+    make_tuple_inputs.push_back(tuple_getitem_r);
+  } else {
+    make_tuple_inputs.push_back(cnode->input(1));
+  }
+  auto make_tuple = graph->NewCNode(make_tuple_inputs);
+  Attr concat_axis = {AXIS, MakeValue(-1)};
+  OperatorAttrs concat_attrs = {concat_axis};
+  std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple};
+  auto concat = graph->NewCNode(concat_inputs);
+  auto conv2d_attrs = CreatConv2DAttrs();
+  auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)});
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, conv2d));
+  return SUCCESS;
+}
+
 ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
   if (!need_exchange_overlap_) {
     if (!out_channel_shard_) {
@@ -579,6 +676,11 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
 
   InferNewOperatorAttrs();
 
+  if (ComputeReplaceGraph(cnode) != SUCCESS) {
+    return nullptr;
+  } else {
+    return replace_graph_;
+  }
   return nullptr;
 }
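
For orientation, the replace graph built above is a halo exchange on the W axis:
each rank slices off its edge columns (StridedSlice), trades them with its
neighbors (NeighborExchange), concatenates the received halos onto its own
slice (Concat, axis -1), and runs the convolution with pad mode "pad". Below is
a minimal single-process sketch of that pattern, with numpy standing in for the
communication op; the shapes and the overlap width of 1 are hypothetical.

    import numpy as np

    def exchange_overlap(slices, overlap):
        """Simulate the halo exchange between ranks that shard the W axis."""
        padded = []
        for rank, x in enumerate(slices):
            # What a rank "receives" is just its neighbors' edge columns.
            left = slices[rank - 1][:, :, :, -overlap:] if rank > 0 else None
            right = slices[rank + 1][:, :, :, :overlap] if rank < len(slices) - 1 else None
            parts = [p for p in (left, x, right) if p is not None]
            padded.append(np.concatenate(parts, axis=3))  # concat on the W axis
        return padded

    x = np.arange(2 * 1 * 4 * 8, dtype=np.float32).reshape(2, 1, 4, 8)  # NCHW
    slices = np.split(x, 2, axis=3)  # each "rank" holds one W-slice
    padded = exchange_overlap(slices, overlap=1)
    print([p.shape for p in padded])  # [(2, 1, 4, 5), (2, 1, 4, 5)]

Each padded slice can then be convolved locally with an explicit pad list,
which is presumably why CreatConv2DAttrs pins pad_mode to "pad" and passes
new_pad_list_.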



mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h (+5 -1)

@@ -55,9 +55,12 @@ class Conv2DInfo : public OperatorInfo {
   Status InferOverlapSize();
   void InferNewOperatorAttrs();
   void InferSendRecvFlag();
-  void InferRecvShapes();
+  void InferOverlapShapes();
   void InferStridedSliceAttrs();
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
+  OperatorAttrs CreatNeighborExchangeAttrs(const CNodePtr &cnode);
+  OperatorAttrs CreatConv2DAttrs();
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
 
   int64_t out_channel_ = 1;
   std::vector<int64_t> kernel_size_;  // two integers
@@ -100,6 +103,7 @@ class Conv2DInfo : public OperatorInfo {
 
   std::vector<int64_t> send_rank_ids_;
   std::vector<int64_t> recv_rank_ids_;
+  Shapes send_shapes_;
   Shapes recv_shapes_;
 
   virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);

mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h (+8 -0)

@@ -156,6 +156,11 @@ constexpr char REPLACE[] = "replace";
 constexpr char CONNSYMBOL[] = "/";
 constexpr char INSTANCE_NAME[] = "instance_name";
 constexpr char SPLIT_SENS[] = "split_sens";
+constexpr char SEND_RNAK_IDS[] = "send_rank_ids";
+constexpr char RECV_RNAK_IDS[] = "recv_rank_ids";
+constexpr char RECV_SHAPES[] = "recv_shapes";
+constexpr char SEND_SHAPES[] = "send_shapes";
+constexpr char RECV_TYPE[] = "recv_type";
 constexpr char SPLIT_TENSOR[] = "split_tensor";
 constexpr char DEV_MAT[] = "dev_mat";
 constexpr char TENSOR_MAP[] = "tensor_map";
@@ -195,6 +200,8 @@ constexpr char KERNEL_SIZE[] = "kernel_size";
 constexpr char MODE[] = "mode";
 constexpr char PAD_MODE[] = "pad_mode";
 constexpr char PAD_LIST[] = "pad_list";
+constexpr char PAD[] = "pad";
+constexpr char DATA_FORMAT[] = "data_format";
 constexpr char STRIDE[] = "stride";
 constexpr char DILATION[] = "dilation";
 constexpr char FORMAT[] = "format";
@@ -209,6 +216,7 @@ constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
 constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";
 constexpr char SPLIT[] = "Split";
 constexpr char ALL_TO_ALL[] = "_AlltoAll";
+constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange";
 constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
 constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
 constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";


mindspore/core/base/core_ops.h (+1 -1)

@@ -388,7 +388,7 @@ inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_Vir
 inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
 inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
 inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
-inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
+inline const PrimitivePtr kPrimNeighborExchange = std::make_shared<Primitive>("NeighborExchange");
 inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");


mindspore/core/ops/alltoallv.cc → mindspore/core/ops/neighborexchange.cc (+5 -5)

@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "ops/alltoallv.h"
+#include "ops/neighborexchange.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
 #include "abstract/primitive_infer_map.h"
@@ -46,7 +46,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("AllToAllv infer", input_args.size(), kEqual, 1, prim_name);
+  CheckAndConvertUtils::CheckInteger("NeighborExchange infer", input_args.size(), kEqual, 1, prim_name);
   MS_EXCEPTION_IF_NULL(input_args[0]);
   auto recv_shapes = primitive->GetAttr(RecvShapes);
   MS_EXCEPTION_IF_NULL(recv_shapes);
@@ -60,13 +60,13 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
   return std::make_shared<Tuple>(type_vec);
 }
 
-AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const std::vector<AbstractBasePtr> &input_args) {
+AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const std::vector<AbstractBasePtr> &input_args) {
   auto type = InferType(primitive, input_args);
   auto shape = InferShape(primitive, input_args);
   return abstract::MakeAbstract(shape, type);
 }
 
-REGISTER_PRIMITIVE_EVAL_IMPL(AllToAllv, prim::kPrimAllToAllv, AllToAllvInfer, nullptr, true);
+REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore

mindspore/core/ops/alltoallv.h → mindspore/core/ops/neighborexchange.h (+11 -11)

@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CORE_OPS_ALLTOALLV_H_
-#define MINDSPORE_CORE_OPS_ALLTOALLV_H_
+#ifndef MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_
+#define MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_
 #include <vector>
 #include <memory>
 #include "ops/primitive_c.h"
@@ -24,20 +24,20 @@
 
 namespace mindspore {
 namespace ops {
-constexpr auto kNameAllToAllv = "AllToAllv";
+constexpr auto kNameNeighborExchange = "NeighborExchange";
 constexpr auto RecvShapes = "recv_shapes";
 constexpr auto RecvType = "recv_type";
-class AllToAllv : public PrimitiveC {
+class NeighborExchange : public PrimitiveC {
  public:
-  AllToAllv() : PrimitiveC(kNameAllToAllv) {}
-  ~AllToAllv() = default;
-  MS_DECLARE_PARENT(AllToAllv, PrimitiveC);
+  NeighborExchange() : PrimitiveC(kNameNeighborExchange) {}
+  ~NeighborExchange() = default;
+  MS_DECLARE_PARENT(NeighborExchange, PrimitiveC);
   void Init() {}
 };
-AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const std::vector<AbstractBasePtr> &input_args);
-using PrimAllToAllPtr = std::shared_ptr<AllToAllv>;
+AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const std::vector<AbstractBasePtr> &input_args);
+using PrimNeighborExchangePtr = std::shared_ptr<NeighborExchange>;
 }  // namespace ops
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CORE_OPS_ALLTOALLV_H_
+#endif  // MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_

mindspore/ops/_grad_experimental/grad_comm_ops.py (+7 -7)

@@ -15,19 +15,19 @@
 
 """Generate bprop for comm ops"""
 from .._grad.grad_base import bprop_getters
-from ..operations._inner_ops import AllToAllv
+from ..operations._inner_ops import NeighborExchange
 
 
-@bprop_getters.register(AllToAllv)
-def get_bprop_alltoallv(self):
-    """Generate bprop for AllToAllv."""
+@bprop_getters.register(NeighborExchange)
+def get_bprop_neighborexchange(self):
+    """Generate bprop for NeighborExchange."""
     group = self.group
     send_rank_ids = self.recv_rank_ids
     recv_rank_ids = self.send_rank_ids
-    recv_shapes = self.recv_shapes_backward
+    recv_shapes = self.send_shapes
     recv_type = self.recv_type
-    alltoallv_grad = AllToAllv(send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes, recv_type, group)
+    neighborexchange_grad = NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes, recv_type, group)
 
     def bprop(x, out, dout):
-        return (alltoallv_grad(dout),)
+        return (neighborexchange_grad(dout),)
     return bprop
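
The swap above follows from the gradient flowing in the opposite direction to
the forward data: every rank the forward op received from is a rank the bprop
op must send to, and the gradients it receives are shaped like the tensors the
forward op sent. A small sketch of that symmetry, with rank ids and shapes
copied from the unit tests below:

    forward = dict(send_rank_ids=[0], recv_rank_ids=[1, 2],
                   recv_shapes=([32, 32], [32, 64]), send_shapes=([32, 32], [32, 16]))
    backward = dict(send_rank_ids=forward["recv_rank_ids"],  # [1, 2]
                    recv_rank_ids=forward["send_rank_ids"],  # [0]
                    recv_shapes=forward["send_shapes"])      # grads of what was sent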

mindspore/ops/operations/_inner_ops.py (+7 -6)

@@ -492,29 +492,30 @@ class Receive(PrimitiveWithInfer):
         return self.dtype
 
 
-class AllToAllv(Primitive):
+class NeighborExchange(Primitive):
     """
-    AlltoAllv is a collective operation.
+    NeighborExchange is a collective operation.
 
-    AlltoAllv sends data from the local rank to ranks in the send_rank_ids, as while receive data from recv_rank_ids.
+    NeighborExchange sends data from the local rank to the ranks in send_rank_ids,
+    while receiving data from the ranks in recv_rank_ids.
 
     Args:
         send_rank_ids (list): Ranks which the data is sent to.
         recv_rank_ids (list): Ranks which the data is received from.
         recv_shapes (list): Data shape which received from recv_rank_ids.
-        recv_shapes_backward (list): Data shape which received from send_rank_ids in the backward.
+        send_shapes (list): Shapes of the data sent to the ranks in send_rank_ids.
         recv_type (type): Data type which received from recv_rank_ids
         group (str):
     """
 
     @prim_attr_register
-    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes_backward, recv_type,
+    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type,
                  group=GlobalComm.WORLD_COMM_GROUP):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
         self.send_rank_ids = send_rank_ids
         self.recv_rank_ids = recv_rank_ids
         self.recv_shapes = recv_shapes
-        self.recv_shapes_backward = recv_shapes_backward
+        self.send_shapes = send_shapes
         self.recv_type = recv_type
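
Judging from the MakeTuple/TupleGetItem wiring in the conv2d replace graph
above, the op is tuple-in/tuple-out: one tensor per rank in send_rank_ids goes
in, one tensor per rank in recv_rank_ids comes out. A hypothetical construction
sketch (shapes copied from the updated tests; actually running it needs an
initialized multi-device communication group, though compiling it does not):

    import mindspore as ms
    from mindspore.ops.operations._inner_ops import NeighborExchange

    exchange = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2],
                                recv_shapes=([32, 32], [32, 64]),
                                send_shapes=([32, 32], [32, 16]),
                                recv_type=ms.float32)
    # inside a Cell's construct: recv_a, recv_b = exchange((send_tensor,))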




tests/ut/python/parallel/test_alltoall_v.py → tests/ut/python/parallel/test_neighborexchange.py (+7 -7)

@@ -20,7 +20,7 @@ import mindspore.nn as nn
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, Momentum
 from mindspore.ops import operations as P
-from mindspore.ops.operations._inner_ops import AllToAllv
+from mindspore.ops.operations._inner_ops import NeighborExchange
 
 
 class MatMulNet(nn.Cell):
@@ -28,8 +28,8 @@ class MatMulNet(nn.Cell):
         super(MatMulNet, self).__init__()
         self.matmul = P.MatMul()
         self.mul = P.Mul()
-        self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
-                                   recv_shapes_backward=([32, 32], [32, 16]), recv_type=ms.float32)
+        self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
+                                          send_shapes=([32, 32], [32, 16]), recv_type=ms.float32)
         self.weight1 = Parameter(weight1, "w1")
 
     def construct(self, x1, x2):
@@ -44,8 +44,8 @@ class MatMulNet2(nn.Cell):
         super(MatMulNet2, self).__init__()
         self.matmul = P.MatMul()
         self.mul = P.Mul()
-        self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
-                                   recv_shapes_backward=([32, 32],), recv_type=ms.float32)
+        self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
+                                          send_shapes=([32, 32],), recv_type=ms.float32)
         self.weight1 = Parameter(weight1, "w1")
 
     def construct(self, x1, x2):
@@ -68,13 +68,13 @@ def compile_net(net):
     _executor.compile(train_net, _x1, _x2)
 
 
-def test_AllToAllv_two_inputs():
+def test_NeighborExchange_two_inputs():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = MatMulNet(_w1)
     compile_net(net)
 
 
-def test_AllToAllv_single_input():
+def test_NeighborExchange_single_input():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = MatMulNet2(_w1)
     compile_net(net)
