|
|
|
@@ -26,7 +26,7 @@ from mindspore.nn.optim.momentum import Momentum |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops.operations.comm_ops import _VirtualDataset |
|
|
|
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell |
|
|
|
from mindspore.parallel import set_algo_parameters |
|
|
|
from mindspore.train import Model |
|
|
|
from mindspore.context import ParallelMode |
|
|
|
@@ -204,14 +204,12 @@ class GradWrap(nn.Cell): |
|
|
|
class ReshapeNet1(nn.Cell): |
|
|
|
def __init__(self, strategy0): |
|
|
|
super(ReshapeNet1, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.matmul = P.MatMul().shard(strategy0) |
|
|
|
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") |
|
|
|
self.reshape2 = P.Reshape() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.virtual_dataset(x) |
|
|
|
x = self.reshape(x, (256, 25088)) |
|
|
|
x = self.matmul(x, self.matmul_weight) |
|
|
|
x = self.reshape2(x, (256 * 256,)) |
|
|
|
@@ -221,7 +219,6 @@ class ReshapeNet1(nn.Cell): |
|
|
|
class ReshapeNet2(nn.Cell): |
|
|
|
def __init__(self, strategy0): |
|
|
|
super(ReshapeNet2, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.matmul = P.MatMul().shard(strategy0) |
|
|
|
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") |
|
|
|
@@ -230,7 +227,6 @@ class ReshapeNet2(nn.Cell): |
|
|
|
self.reshape3 = P.Reshape() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.virtual_dataset(x) |
|
|
|
x = self.reshape(x, (256, 25088)) |
|
|
|
x = self.matmul(x, self.matmul_weight) |
|
|
|
x = self.reshape2(x, (256 * 256,)) |
|
|
|
@@ -242,7 +238,6 @@ class ReshapeNet2(nn.Cell): |
|
|
|
class ReshapeNet3(nn.Cell): |
|
|
|
def __init__(self, strategy0): |
|
|
|
super(ReshapeNet3, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.matmul = P.MatMul().shard(strategy0) |
|
|
|
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") |
|
|
|
@@ -251,7 +246,6 @@ class ReshapeNet3(nn.Cell): |
|
|
|
self.reshape3 = P.Reshape() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.virtual_dataset(x) |
|
|
|
x = self.reshape(x, (256, 25088)) |
|
|
|
x = self.matmul(x, self.matmul_weight) |
|
|
|
x = self.reshape2(x, (256 * 256,)) |
|
|
|
@@ -263,14 +257,12 @@ class ReshapeNet3(nn.Cell): |
|
|
|
class ReshapeNet4(nn.Cell): |
|
|
|
def __init__(self, strategy0): |
|
|
|
super(ReshapeNet4, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.reshape2 = P.Reshape() |
|
|
|
self.matmul = P.MatMul().shard(strategy0) |
|
|
|
self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.virtual_dataset(x) |
|
|
|
x = self.reshape(x, (256, 25088)) |
|
|
|
w = self.reshape2(self.matmul_weight, (25088, 256)) |
|
|
|
x = self.matmul(x, w) |
|
|
|
@@ -280,14 +272,12 @@ class ReshapeNet4(nn.Cell): |
|
|
|
class ReshapeNet5(nn.Cell): |
|
|
|
def __init__(self, strategy0): |
|
|
|
super(ReshapeNet5, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.matmul1 = P.MatMul().shard(strategy0) |
|
|
|
self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight") |
|
|
|
self.matmul2 = P.MatMul().shard(strategy0) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.virtual_dataset(x) |
|
|
|
x = self.reshape(x, (256, 25088)) |
|
|
|
matmul1_o = self.matmul1(x, self.matmul1_weight) |
|
|
|
matmul2_o = self.matmul2(matmul1_o, x) |
|
|
|
@@ -297,7 +287,6 @@ class ReshapeNet5(nn.Cell): |
|
|
|
class ReshapeNet6(nn.Cell): |
|
|
|
def __init__(self, strategy0): |
|
|
|
super(ReshapeNet6, self).__init__() |
|
|
|
self.virtual_dataset = _VirtualDataset() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.matmul1_1 = P.MatMul().shard(strategy0) |
|
|
|
self.matmul1_2 = P.MatMul().shard(strategy0) |
|
|
|
@@ -306,7 +295,6 @@ class ReshapeNet6(nn.Cell): |
|
|
|
self.add = P.TensorAdd() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.virtual_dataset(x) |
|
|
|
x = self.reshape(x, (256, 25088)) |
|
|
|
matmul1_1_o = self.matmul1_1(x, self.matmul1_weight) |
|
|
|
matmul1_2_o = self.matmul1_2(x, self.matmul1_weight) |
|
|
|
@@ -334,32 +322,32 @@ def reshape_net2(backbone): |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net1_1(): |
|
|
|
reshape_net2(ReshapeNet1(((1, 8), (8, 1)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net1_2(): |
|
|
|
reshape_net2(ReshapeNet1(((1, 8), (8, 2)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net2_1(): |
|
|
|
reshape_net2(ReshapeNet2(((1, 8), (8, 1)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net2_2(): |
|
|
|
reshape_net2(ReshapeNet2(((1, 8), (8, 2)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net3_1(): |
|
|
|
reshape_net2(ReshapeNet3(((1, 8), (8, 1)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net3_2(): |
|
|
|
reshape_net2(ReshapeNet3(((1, 8), (8, 2)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net4_1(): |
|
|
|
try: |
|
|
|
reshape_net2(ReshapeNet4(((1, 8), (8, 1)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 1))))) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
except TypeError: |
|
|
|
@@ -370,7 +358,7 @@ def test_reshape_net4_1(): |
|
|
|
|
|
|
|
def test_reshape_net4_2(): |
|
|
|
try: |
|
|
|
reshape_net2(ReshapeNet4(((1, 8), (8, 2)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2))))) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
except TypeError: |
|
|
|
@@ -380,19 +368,19 @@ def test_reshape_net4_2(): |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net5_1(): |
|
|
|
reshape_net2(ReshapeNet5(((1, 8), (8, 1)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net5_2(): |
|
|
|
reshape_net2(ReshapeNet5(((1, 8), (8, 2)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net6_1(): |
|
|
|
reshape_net2(ReshapeNet6(((1, 8), (8, 1)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_net6_2(): |
|
|
|
reshape_net2(ReshapeNet6(((1, 8), (8, 2)))) |
|
|
|
reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
class TrainOneStepCell(nn.Cell): |
|
|
|
@@ -453,39 +441,37 @@ def reshape_common2(parallel_mode, net): |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_common2_0(): |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 1)))) |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_common2_1(): |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 2)))) |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet1(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_common2_2(): |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 1)))) |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_common2_3(): |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 2)))) |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet2(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_common2_4(): |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 1)))) |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 1))))) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_common2_5(): |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 2)))) |
|
|
|
reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2))))) |
|
|
|
|
|
|
|
|
|
|
|
class BatchNormReshapeNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(BatchNormReshapeNet, self).__init__() |
|
|
|
self.vd = P._VirtualDataset() |
|
|
|
self.batch_norm = nn.BatchNorm1d(512, affine=False) |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.prelu = nn.PReLU(channel=256) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.vd(x) |
|
|
|
x = self.batch_norm(x) |
|
|
|
x = self.reshape(x, (512, 256)) |
|
|
|
x = self.prelu(x) |
|
|
|
@@ -499,7 +485,7 @@ def test_batchnorm_reshape_train(): |
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") |
|
|
|
input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01) |
|
|
|
|
|
|
|
net = GradWrap(NetWithLoss(BatchNormReshapeNet())) |
|
|
|
net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet()))) |
|
|
|
|
|
|
|
compile_net(net, input_) |
|
|
|
|
|
|
|
|