|
|
|
@@ -153,7 +153,7 @@ class TransformerNet(nn.Cell): |
|
|
|
ffn_hidden_size=64, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.loss = CrossEntropyLoss(parallel_config=config.moe_parallel_config) |
|
|
|
self.loss = CrossEntropyLoss(parallel_config=parallel_config.moe_parallel_config) |
|
|
|
|
|
|
|
def construct(self, x1, x2, x3, x4, x5, y, mask): |
|
|
|
predict, _, _ = self.network(x1, x2, x3, x4, x5) |
|
|
|
|