| @@ -211,7 +211,6 @@ class WideDeepModel(nn.Cell): | |||||
| if config.deep_table_slice_mode == "column_slice": | if config.deep_table_slice_mode == "column_slice": | ||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | ||||
| slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) | slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) | ||||
| self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) | |||||
| self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | ||||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | ||||
| self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size) | self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size) | ||||
| @@ -233,7 +232,6 @@ class WideDeepModel(nn.Cell): | |||||
| self.deep_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1))) | self.deep_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1))) | ||||
| self.wide_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1))) | self.wide_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1))) | ||||
| self.reduce_sum.shard(((1, get_group_size(), 1),)) | self.reduce_sum.shard(((1, get_group_size(), 1),)) | ||||
| self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) | |||||
| self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | ||||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | ||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | self.embedding_table = self.deep_embeddinglookup.embedding_table | ||||