Merge pull request !20960 from lichen/add_replace_graph_for_conv2dtags/v1.4.0
| @@ -100,6 +100,14 @@ AnfNodePtr CreatInt64Imm(int64_t value) { | |||
| return ValuePtrToAnfNodePtr(value_ptr); | |||
| } | |||
| AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple) { | |||
| std::vector<ValuePtr> value_list; | |||
| std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list), | |||
| [](const int64_t value) { return MakeValue(value); }); | |||
| ValueTuplePtr value_tuple_ptr = std::make_shared<ValueTuple>(value_list); | |||
| return ValuePtrToAnfNodePtr(value_tuple_ptr); | |||
| } | |||
| std::string GetInstanceNameByCNode(const CNodePtr &cnode) { | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (!prim) { | |||
| @@ -41,6 +41,7 @@ AnfNodePtr CreatTypeInt(int64_t value); | |||
| AnfNodePtr CreatInt64Imm(int64_t value); | |||
| AnfNodePtr CreateInt32Tensor(int64_t value); | |||
| AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr); | |||
| AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple); | |||
| std::string HashInstanceName(const std::string &name); | |||
| class GenerateGraph { | |||
| @@ -25,6 +25,7 @@ | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| #include "pipeline/jit/resource.h" | |||
| namespace mindspore { | |||
| @@ -230,7 +231,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) { | |||
| if (weight_strategy[0] > 1) { | |||
| out_channel_shard_ = true; | |||
| new_out_channel_ = out_channel_ / weight_strategy[1]; | |||
| new_out_channel_ = out_channel_ / weight_strategy[0]; | |||
| } else { | |||
| out_channel_shard_ = false; | |||
| } | |||
| @@ -514,7 +515,7 @@ void Conv2DInfo::InferSendRecvFlag() { | |||
| MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_; | |||
| } | |||
| void Conv2DInfo::InferRecvShapes() { | |||
| void Conv2DInfo::InferOverlapShapes() { | |||
| if (left_need_recv_) { | |||
| Shape left_recv_shape = input_slice_shape_; | |||
| left_recv_shape[3] = overlap_left_size_; | |||
| @@ -535,6 +536,9 @@ void Conv2DInfo::InferStridedSliceAttrs() { | |||
| left_strided_slice_end_ = input_slice_shape_; | |||
| left_strided_slice_end_[3] = left_rank_overlap_right_size_; | |||
| left_strided_slice_strides_ = {1, 1, 1, 1}; | |||
| Shape left_send_shape = input_slice_shape_; | |||
| left_send_shape[3] = left_rank_overlap_right_size_; | |||
| send_shapes_.push_back(left_send_shape); | |||
| MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is " | |||
| << left_strided_slice_end_; | |||
| } | |||
| @@ -544,6 +548,9 @@ void Conv2DInfo::InferStridedSliceAttrs() { | |||
| right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_; | |||
| right_strided_slice_end_ = input_slice_shape_; | |||
| right_strided_slice_strides_ = {1, 1, 1, 1}; | |||
| Shape right_send_shape = input_slice_shape_; | |||
| right_send_shape[3] = right_rank_overlap_left_size_; | |||
| send_shapes_.push_back(right_send_shape); | |||
| MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is " | |||
| << right_strided_slice_end_; | |||
| } | |||
| @@ -554,11 +561,101 @@ void Conv2DInfo::InferNewOperatorAttrs() { | |||
| InferSendRecvFlag(); | |||
| InferRecvShapes(); | |||
| InferOverlapShapes(); | |||
| InferStridedSliceAttrs(); | |||
| } | |||
| OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) { | |||
| auto type = cnode->Type(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| auto tensor_type = type->cast<mindspore::TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto dtype = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(dtype); | |||
| Attr send_ranks = {SEND_RNAK_IDS, MakeValue(send_rank_ids_)}; | |||
| Attr recv_ranks = {RECV_RNAK_IDS, MakeValue(recv_rank_ids_)}; | |||
| Attr send_shapes = {SEND_SHAPES, MakeValue(send_shapes_)}; | |||
| Attr recv_shapes = {RECV_SHAPES, MakeValue(recv_shapes_)}; | |||
| Attr recv_type = {RECV_TYPE, dtype}; | |||
| OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type}; | |||
| return attrs; | |||
| } | |||
| OperatorAttrs Conv2DInfo::CreatConv2DAttrs() { | |||
| Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)}; | |||
| Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)}; | |||
| Attr mode = {MODE, MakeValue(mode_)}; | |||
| Attr pad_mode = {PAD_MODE, MakeValue("pad")}; | |||
| Attr pad = {PAD, MakeValue(new_pad_list_)}; | |||
| Attr stride = {STRIDE, MakeValue(stride_)}; | |||
| Attr dilation = {DILATION, MakeValue(dilation_)}; | |||
| Attr group = {GROUP, MakeValue(group_)}; | |||
| Attr data_format = {DATA_FORMAT, MakeValue(format_)}; | |||
| OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format}; | |||
| return attrs; | |||
| } | |||
| Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| auto graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| GenerateGraph gen_g = GenerateGraph(attrs_); | |||
| if (gen_g.Init(cnode) != SUCCESS) { | |||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | |||
| return FAILED; | |||
| } | |||
| std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes; | |||
| std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| if (left_need_send_) { | |||
| auto slice_left_begin = CreatTuple(left_strided_slice_begin_); | |||
| auto slice_left_end = CreatTuple(left_strided_slice_end_); | |||
| auto slice_left_strided = CreatTuple(left_strided_slice_strides_); | |||
| auto slice_left = gen_g.PushBack( | |||
| {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided}); | |||
| make_tuple_a_inputs.push_back(slice_left); | |||
| } | |||
| if (right_need_send_) { | |||
| auto slice_right_begin = CreatTuple(right_strided_slice_begin_); | |||
| auto slice_right_end = CreatTuple(right_strided_slice_end_); | |||
| auto slice_right_strided = CreatTuple(right_strided_slice_strides_); | |||
| auto slice_right = gen_g.PushBack( | |||
| {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided}); | |||
| make_tuple_a_inputs.push_back(slice_right); | |||
| } | |||
| auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs); | |||
| auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode); | |||
| auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a}); | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| if (left_need_recv_) { | |||
| std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v, | |||
| CreatInt64Imm(0)}; | |||
| auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs); | |||
| std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1), | |||
| tuple_getitem_l}; | |||
| auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs); | |||
| auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l}); | |||
| make_tuple_inputs.push_back(concat_l); | |||
| } | |||
| if (right_need_recv_) { | |||
| std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v, | |||
| CreatInt64Imm(0)}; | |||
| auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs); | |||
| make_tuple_inputs.push_back(tuple_getitem_r); | |||
| } else { | |||
| make_tuple_inputs.push_back(cnode->input(1)); | |||
| } | |||
| auto make_tuple = graph->NewCNode(make_tuple_inputs); | |||
| Attr concat_axis = {AXIS, MakeValue(-1)}; | |||
| OperatorAttrs concat_attrs = {concat_axis}; | |||
| std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple}; | |||
| auto concat = graph->NewCNode(concat_inputs); | |||
| auto conv2d_attrs = CreatConv2DAttrs(); | |||
| auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)}); | |||
| replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( | |||
| std::make_pair(input_nodes, conv2d)); | |||
| return SUCCESS; | |||
| } | |||
| ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) { | |||
| if (!need_exchange_overlap_) { | |||
| if (!out_channel_shard_) { | |||
| @@ -579,6 +676,11 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) { | |||
| InferNewOperatorAttrs(); | |||
| if (ComputeReplaceGraph(cnode) != SUCCESS) { | |||
| return nullptr; | |||
| } else { | |||
| return replace_graph_; | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -55,9 +55,12 @@ class Conv2DInfo : public OperatorInfo { | |||
| Status InferOverlapSize(); | |||
| void InferNewOperatorAttrs(); | |||
| void InferSendRecvFlag(); | |||
| void InferRecvShapes(); | |||
| void InferOverlapShapes(); | |||
| void InferStridedSliceAttrs(); | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| OperatorAttrs CreatNeighborExchangeAttrs(const CNodePtr &cnode); | |||
| OperatorAttrs CreatConv2DAttrs(); | |||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | |||
| int64_t out_channel_ = 1; | |||
| std::vector<int64_t> kernel_size_; // two integers | |||
| @@ -100,6 +103,7 @@ class Conv2DInfo : public OperatorInfo { | |||
| std::vector<int64_t> send_rank_ids_; | |||
| std::vector<int64_t> recv_rank_ids_; | |||
| Shapes send_shapes_; | |||
| Shapes recv_shapes_; | |||
| virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy); | |||
| @@ -156,6 +156,11 @@ constexpr char REPLACE[] = "replace"; | |||
| constexpr char CONNSYMBOL[] = "/"; | |||
| constexpr char INSTANCE_NAME[] = "instance_name"; | |||
| constexpr char SPLIT_SENS[] = "split_sens"; | |||
| constexpr char SEND_RNAK_IDS[] = "send_rank_ids"; | |||
| constexpr char RECV_RNAK_IDS[] = "recv_rank_ids"; | |||
| constexpr char RECV_SHAPES[] = "recv_shapes"; | |||
| constexpr char SEND_SHAPES[] = "send_shapes"; | |||
| constexpr char RECV_TYPE[] = "recv_type"; | |||
| constexpr char SPLIT_TENSOR[] = "split_tensor"; | |||
| constexpr char DEV_MAT[] = "dev_mat"; | |||
| constexpr char TENSOR_MAP[] = "tensor_map"; | |||
| @@ -195,6 +200,8 @@ constexpr char KERNEL_SIZE[] = "kernel_size"; | |||
| constexpr char MODE[] = "mode"; | |||
| constexpr char PAD_MODE[] = "pad_mode"; | |||
| constexpr char PAD_LIST[] = "pad_list"; | |||
| constexpr char PAD[] = "pad"; | |||
| constexpr char DATA_FORMAT[] = "data_format"; | |||
| constexpr char STRIDE[] = "stride"; | |||
| constexpr char DILATION[] = "dilation"; | |||
| constexpr char FORMAT[] = "format"; | |||
| @@ -209,6 +216,7 @@ constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; | |||
| constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice"; | |||
| constexpr char SPLIT[] = "Split"; | |||
| constexpr char ALL_TO_ALL[] = "_AlltoAll"; | |||
| constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange"; | |||
| constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis"; | |||
| constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; | |||
| constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; | |||
| @@ -388,7 +388,7 @@ inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_Vir | |||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | |||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv"); | |||
| inline const PrimitivePtr kPrimNeighborExchange = std::make_shared<Primitive>("NeighborExchange"); | |||
| inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | |||
| inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); | |||
| inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather"); | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/alltoallv.h" | |||
| #include "ops/neighborexchange.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| @@ -46,7 +46,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec | |||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("AllToAllv infer", input_args.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("NeighborExchange infer", input_args.size(), kEqual, 1, prim_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| auto recv_shapes = primitive->GetAttr(RecvShapes); | |||
| MS_EXCEPTION_IF_NULL(recv_shapes); | |||
| @@ -60,13 +60,13 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP | |||
| return std::make_shared<Tuple>(type_vec); | |||
| } | |||
| AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto type = InferType(primitive, input_args); | |||
| auto shape = InferShape(primitive, input_args); | |||
| return abstract::MakeAbstract(shape, type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AllToAllv, prim::kPrimAllToAllv, AllToAllvInfer, nullptr, true); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_ALLTOALLV_H_ | |||
| #define MINDSPORE_CORE_OPS_ALLTOALLV_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_ | |||
| #define MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| @@ -24,20 +24,20 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAllToAllv = "AllToAllv"; | |||
| constexpr auto kNameNeighborExchange = "NeighborExchange"; | |||
| constexpr auto RecvShapes = "recv_shapes"; | |||
| constexpr auto RecvType = "recv_type"; | |||
| class AllToAllv : public PrimitiveC { | |||
| class NeighborExchange : public PrimitiveC { | |||
| public: | |||
| AllToAllv() : PrimitiveC(kNameAllToAllv) {} | |||
| ~AllToAllv() = default; | |||
| MS_DECLARE_PARENT(AllToAllv, PrimitiveC); | |||
| NeighborExchange() : PrimitiveC(kNameNeighborExchange) {} | |||
| ~NeighborExchange() = default; | |||
| MS_DECLARE_PARENT(NeighborExchange, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAllToAllPtr = std::shared_ptr<AllToAllv>; | |||
| AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimNeighborExchangePtr = std::shared_ptr<NeighborExchange>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_ALLTOALLV_H_ | |||
| #endif // MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_ | |||
| @@ -15,19 +15,19 @@ | |||
| """Generate bprop for comm ops""" | |||
| from .._grad.grad_base import bprop_getters | |||
| from ..operations._inner_ops import AllToAllv | |||
| from ..operations._inner_ops import NeighborExchange | |||
| @bprop_getters.register(AllToAllv) | |||
| def get_bprop_alltoallv(self): | |||
| """Generate bprop for AllToAllv.""" | |||
| @bprop_getters.register(NeighborExchange) | |||
| def get_bprop_neighborexchange(self): | |||
| """Generate bprop for NeighborExchange.""" | |||
| group = self.group | |||
| send_rank_ids = self.recv_rank_ids | |||
| recv_rank_ids = self.send_rank_ids | |||
| recv_shapes = self.recv_shapes_backward | |||
| recv_shapes = self.send_shapes | |||
| recv_type = self.recv_type | |||
| alltoallv_grad = AllToAllv(send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes, recv_type, group) | |||
| neighborexchange_grad = NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes, recv_type, group) | |||
| def bprop(x, out, dout): | |||
| return (alltoallv_grad(dout),) | |||
| return (neighborexchange_grad(dout),) | |||
| return bprop | |||
| @@ -492,29 +492,30 @@ class Receive(PrimitiveWithInfer): | |||
| return self.dtype | |||
| class AllToAllv(Primitive): | |||
| class NeighborExchange(Primitive): | |||
| """ | |||
| AlltoAllv is a collective operation. | |||
| NeighborExchange is a collective operation. | |||
| AlltoAllv sends data from the local rank to ranks in the send_rank_ids, as while receive data from recv_rank_ids. | |||
| NeighborExchange sends data from the local rank to ranks in the send_rank_ids, | |||
| as while receive data from recv_rank_ids. | |||
| Args: | |||
| send_rank_ids (list): Ranks which the data is sent to. | |||
| recv_rank_ids (list): Ranks which the data is received from. | |||
| recv_shapes (list): Data shape which received from recv_rank_ids. | |||
| recv_shapes_backward (list): Data shape which received from send_rank_ids in the backward. | |||
| send_shapes (list): Data shape which send to the send_rank_ids. | |||
| recv_type (type): Data type which received from recv_rank_ids | |||
| group (str): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes_backward, recv_type, | |||
| def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, | |||
| group=GlobalComm.WORLD_COMM_GROUP): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| self.send_rank_ids = send_rank_ids | |||
| self.recv_rank_ids = recv_rank_ids | |||
| self.recv_shapes = recv_shapes | |||
| self.recv_shapes_backward = recv_shapes_backward | |||
| self.send_shapes = send_shapes | |||
| self.recv_type = recv_type | |||
| @@ -20,7 +20,7 @@ import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations._inner_ops import AllToAllv | |||
| from mindspore.ops.operations._inner_ops import NeighborExchange | |||
| class MatMulNet(nn.Cell): | |||
| @@ -28,8 +28,8 @@ class MatMulNet(nn.Cell): | |||
| super(MatMulNet, self).__init__() | |||
| self.matmul = P.MatMul() | |||
| self.mul = P.Mul() | |||
| self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]), | |||
| recv_shapes_backward=([32, 32], [32, 16]), recv_type=ms.float32) | |||
| self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]), | |||
| send_shapes=([32, 32], [32, 16]), recv_type=ms.float32) | |||
| self.weight1 = Parameter(weight1, "w1") | |||
| def construct(self, x1, x2): | |||
| @@ -44,8 +44,8 @@ class MatMulNet2(nn.Cell): | |||
| super(MatMulNet2, self).__init__() | |||
| self.matmul = P.MatMul() | |||
| self.mul = P.Mul() | |||
| self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]), | |||
| recv_shapes_backward=([32, 32],), recv_type=ms.float32) | |||
| self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]), | |||
| send_shapes=([32, 32],), recv_type=ms.float32) | |||
| self.weight1 = Parameter(weight1, "w1") | |||
| def construct(self, x1, x2): | |||
| @@ -68,13 +68,13 @@ def compile_net(net): | |||
| _executor.compile(train_net, _x1, _x2) | |||
| def test_AllToAllv_two_inputs(): | |||
| def test_NeighborExchange_two_inputs(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| net = MatMulNet(_w1) | |||
| compile_net(net) | |||
| def test_AllToAllv_single_input(): | |||
| def test_NeighborExchange_single_input(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| net = MatMulNet2(_w1) | |||
| compile_net(net) | |||