diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index b8699f32fa..a7fbdd56e8 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -90,10 +90,6 @@ class VirtualDatasetEliminater : public AnfVisitor { std::vector args; (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); - if (args.size() == 1) { - return args.front(); - } - (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple)); return node->func_graph()->NewCNode(args); diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 37b730eb90..c1fb36635c 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -304,9 +304,9 @@ class _VirtualDatasetCell(Cell): self._backbone = backbone self._virtual_dataset = _VirtualDataset() - def construct(self, data, label): - data_, label_ = self._virtual_dataset(data, label) - return self._backbone(data_, label_) + def construct(self, *inputs): + output = self._virtual_dataset(*inputs) + return self._backbone(*output) class VirtualDatasetCellTriple(Cell): diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 1cf5a61299..c2bfa1f178 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -689,13 +689,9 @@ class _VirtualDataset(PrimitiveWithInfer): """init""" def infer_shape(self, *args): - if len(args) == 1: - return args[0] return args def infer_dtype(self, *args): - if len(args) == 1: - return args[0] return args diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py b/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py new file mode 100644 index 0000000000..5f2aeeb3f6 --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_resnet_predict.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from mindspore import Tensor +from mindspore import context +from mindspore.communication.management import init +from mindspore.parallel import set_algo_parameters +from mindspore.train.model import Model +from mindspore.context import ParallelMode +from .test_auto_parallel_resnet import resnet50 + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +context.set_context(device_id=0) +init() + +def test_train_32k_8p(batch_size=32, num_classes=32768): + dev_num = 8 + context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) + set_algo_parameters(elementwise_op_strategy_follow=True) + np.random.seed(6) + input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32)) + net = resnet50(num_classes) + model = Model(net) + model.predict(input_np) diff --git a/tests/ut/python/parallel/test_onehot.py b/tests/ut/python/parallel/test_onehot.py index d39fe28ff6..89d420660c 100644 --- a/tests/ut/python/parallel/test_onehot.py +++ b/tests/ut/python/parallel/test_onehot.py @@ -21,7 +21,7 @@ from mindspore import context from mindspore.common.api import _executor from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore.ops.operations.comm_ops import _VirtualDataset +from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell context.set_context(mode=context.GRAPH_MODE) @@ -32,7 +32,6 @@ grad_all = C.GradOperation(get_all=True) class NetWithLoss(nn.Cell): def __init__(self, network, strategy3, strategy4, axis): super(NetWithLoss, self).__init__() - self.virtual_dataset = _VirtualDataset() self.one_hot = P.OneHot(axis=axis).shard(strategy3) self.on_value = Tensor(2.0, ms.float32) self.off_value = Tensor(1.0, ms.float32) @@ -40,9 +39,8 @@ class NetWithLoss(nn.Cell): self.network = network def construct(self, x, y, b): - b_virtual = self.virtual_dataset(b) predict = self.network(x, y) - label = self.one_hot(b_virtual, 64, self.on_value, self.off_value) + label = self.one_hot(b, 64, self.on_value, self.off_value) return self.loss(predict, label)[0] @@ -68,7 +66,7 @@ class Net(nn.Cell): def compile_graph(strategy1, strategy2, strategy3, strategy4, auto=False, onthot_axis=-1): - net = GradWrap(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis)) + net = GradWrap(_VirtualDatasetCell(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis))) net.set_auto_parallel() if auto: context.set_auto_parallel_context(parallel_mode="auto_parallel") diff --git a/tests/ut/python/parallel/test_reshape.py b/tests/ut/python/parallel/test_reshape.py index 00b25c1afc..ad83e93cb1 100644 --- a/tests/ut/python/parallel/test_reshape.py +++ b/tests/ut/python/parallel/test_reshape.py @@ -26,7 +26,7 @@ from mindspore.nn.optim.momentum import Momentum from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P -from mindspore.ops.operations.comm_ops import _VirtualDataset +from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.parallel import set_algo_parameters from mindspore.train import Model from mindspore.context import ParallelMode @@ -204,14 +204,12 @@ class GradWrap(nn.Cell): class ReshapeNet1(nn.Cell): def __init__(self, strategy0): super(ReshapeNet1, self).__init__() - self.virtual_dataset = _VirtualDataset() self.reshape = P.Reshape() self.matmul = P.MatMul().shard(strategy0) self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") self.reshape2 = P.Reshape() def construct(self, x): - x = self.virtual_dataset(x) x = self.reshape(x, (256, 25088)) x = self.matmul(x, self.matmul_weight) x = self.reshape2(x, (256 * 256,)) @@ -221,7 +219,6 @@ class ReshapeNet1(nn.Cell): class ReshapeNet2(nn.Cell): def __init__(self, strategy0): super(ReshapeNet2, self).__init__() - self.virtual_dataset = _VirtualDataset() self.reshape = P.Reshape() self.matmul = P.MatMul().shard(strategy0) self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") @@ -230,7 +227,6 @@ class ReshapeNet2(nn.Cell): self.reshape3 = P.Reshape() def construct(self, x): - x = self.virtual_dataset(x) x = self.reshape(x, (256, 25088)) x = self.matmul(x, self.matmul_weight) x = self.reshape2(x, (256 * 256,)) @@ -242,7 +238,6 @@ class ReshapeNet2(nn.Cell): class ReshapeNet3(nn.Cell): def __init__(self, strategy0): super(ReshapeNet3, self).__init__() - self.virtual_dataset = _VirtualDataset() self.reshape = P.Reshape() self.matmul = P.MatMul().shard(strategy0) self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") @@ -251,7 +246,6 @@ class ReshapeNet3(nn.Cell): self.reshape3 = P.Reshape() def construct(self, x): - x = self.virtual_dataset(x) x = self.reshape(x, (256, 25088)) x = self.matmul(x, self.matmul_weight) x = self.reshape2(x, (256 * 256,)) @@ -263,14 +257,12 @@ class ReshapeNet3(nn.Cell): class ReshapeNet4(nn.Cell): def __init__(self, strategy0): super(ReshapeNet4, self).__init__() - self.virtual_dataset = _VirtualDataset() self.reshape = P.Reshape() self.reshape2 = P.Reshape() self.matmul = P.MatMul().shard(strategy0) self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") def construct(self, x): - x = self.virtual_dataset(x) x = self.reshape(x, (256, 25088)) w = self.reshape2(self.matmul_weight, (25088, 256)) x = self.matmul(x, w) @@ -280,14 +272,12 @@ class ReshapeNet4(nn.Cell): class ReshapeNet5(nn.Cell): def __init__(self, strategy0): super(ReshapeNet5, self).__init__() - self.virtual_dataset = _VirtualDataset() self.reshape = P.Reshape() self.matmul1 = P.MatMul().shard(strategy0) self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") self.matmul2 = P.MatMul().shard(strategy0) def construct(self, x): - x = self.virtual_dataset(x) x = self.reshape(x, (256, 25088)) matmul1_o = self.matmul1(x, self.matmul1_weight) matmul2_o = self.matmul2(matmul1_o, x) @@ -297,7 +287,6 @@ class ReshapeNet5(nn.Cell): class ReshapeNet6(nn.Cell): def __init__(self, strategy0): super(ReshapeNet6, self).__init__() - self.virtual_dataset = _VirtualDataset() self.reshape = P.Reshape() self.matmul1_1 = P.MatMul().shard(strategy0) self.matmul1_2 = P.MatMul().shard(strategy0) @@ -306,7 +295,6 @@ class ReshapeNet6(nn.Cell): self.add = P.TensorAdd() def construct(self, x): - x = self.virtual_dataset(x) x = self.reshape(x, (256, 25088)) matmul1_1_o = self.matmul1_1(x, self.matmul1_weight) matmul1_2_o = self.matmul1_2(x, self.matmul1_weight) @@ -334,32 +322,32 @@ def reshape_net2(backbone): def test_reshape_net1_1(): - reshape_net2(ReshapeNet1(((1, 8), (8, 1)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1))))) def test_reshape_net1_2(): - reshape_net2(ReshapeNet1(((1, 8), (8, 2)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2))))) def test_reshape_net2_1(): - reshape_net2(ReshapeNet2(((1, 8), (8, 1)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1))))) def test_reshape_net2_2(): - reshape_net2(ReshapeNet2(((1, 8), (8, 2)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2))))) def test_reshape_net3_1(): - reshape_net2(ReshapeNet3(((1, 8), (8, 1)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1))))) def test_reshape_net3_2(): - reshape_net2(ReshapeNet3(((1, 8), (8, 2)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2))))) def test_reshape_net4_1(): try: - reshape_net2(ReshapeNet4(((1, 8), (8, 1)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 1))))) except ValueError: pass except TypeError: @@ -370,7 +358,7 @@ def test_reshape_net4_1(): def test_reshape_net4_2(): try: - reshape_net2(ReshapeNet4(((1, 8), (8, 2)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2))))) except ValueError: pass except TypeError: @@ -380,19 +368,19 @@ def test_reshape_net4_2(): def test_reshape_net5_1(): - reshape_net2(ReshapeNet5(((1, 8), (8, 1)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 1))))) def test_reshape_net5_2(): - reshape_net2(ReshapeNet5(((1, 8), (8, 2)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 2))))) def test_reshape_net6_1(): - reshape_net2(ReshapeNet6(((1, 8), (8, 1)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 1))))) def test_reshape_net6_2(): - reshape_net2(ReshapeNet6(((1, 8), (8, 2)))) + reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 2))))) class TrainOneStepCell(nn.Cell): @@ -453,39 +441,37 @@ def reshape_common2(parallel_mode, net): def test_reshape_common2_0(): - reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 1)))) + reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1))))) def test_reshape_common2_1(): - reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 2)))) + reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2))))) def test_reshape_common2_2(): - reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 1)))) + reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1))))) def test_reshape_common2_3(): - reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 2)))) + reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2))))) def test_reshape_common2_4(): - reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 1)))) + reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1))))) def test_reshape_common2_5(): - reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 2)))) + reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2))))) class BatchNormReshapeNet(nn.Cell): def __init__(self): super(BatchNormReshapeNet, self).__init__() - self.vd = P._VirtualDataset() self.batch_norm = nn.BatchNorm1d(512, affine=False) self.reshape = P.Reshape() self.prelu = nn.PReLU(channel=256) def construct(self, x): - x = self.vd(x) x = self.batch_norm(x) x = self.reshape(x, (512, 256)) x = self.prelu(x) @@ -499,7 +485,7 @@ def test_batchnorm_reshape_train(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01) - net = GradWrap(NetWithLoss(BatchNormReshapeNet())) + net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet()))) compile_net(net, input_)