|
|
|
@@ -247,15 +247,15 @@ def fc_with_initialize(input_channels, out_channels): |
|
|
|
class BNReshapeDenseBNNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(BNReshapeDenseBNNet, self).__init__() |
|
|
|
self.batch_norm = bn_with_initialize(512) |
|
|
|
self.batch_norm = bn_with_initialize(2) |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.batch_norm2 = nn.BatchNorm1d(512, affine=False) |
|
|
|
self.fc = fc_with_initialize(512 * 32 * 32, 512) |
|
|
|
self.fc = fc_with_initialize(2 * 32 * 32, 512) |
|
|
|
self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch()) |
|
|
|
|
|
|
|
def construct(self, x, label): |
|
|
|
x = self.batch_norm(x) |
|
|
|
x = self.reshape(x, (16, 512*32*32)) |
|
|
|
x = self.reshape(x, (16, 2*32*32)) |
|
|
|
x = self.fc(x) |
|
|
|
x = self.batch_norm2(x) |
|
|
|
loss = self.loss(x, label) |
|
|
|
@@ -266,7 +266,7 @@ def test_bn_reshape_dense_bn_train_loss(): |
|
|
|
batch_size = 16 |
|
|
|
device_num = 16 |
|
|
|
context.set_auto_parallel_context(device_num=device_num, global_rank=0) |
|
|
|
input = Tensor(np.ones([batch_size, 512, 32, 32]).astype(np.float32) * 0.01) |
|
|
|
input = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01) |
|
|
|
label = Tensor(np.ones([batch_size]), dtype=ms.int32) |
|
|
|
|
|
|
|
net = GradWrap(NetWithLoss(BNReshapeDenseBNNet())) |
|
|
|
|