| @@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell): | |||||
| self.deep_layer_act, | self.deep_layer_act, | ||||
| use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) | use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) | ||||
| self.embeddinglookup = nn.EmbeddingLookup() | |||||
| self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE') | |||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | self.reduce_sum = P.ReduceSum(keep_dims=False) | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| @@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell): | |||||
| """ | """ | ||||
| mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | ||||
| # Wide layer | # Wide layer | ||||
| wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0) | |||||
| wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr) | |||||
| wx = self.mul(wide_id_weight, mask) | wx = self.mul(wide_id_weight, mask) | ||||
| wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | ||||
| # Deep layer | # Deep layer | ||||
| deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0) | |||||
| deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr) | |||||
| vx = self.mul(deep_id_embs, mask) | vx = self.mul(deep_id_embs, mask) | ||||
| deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) | deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) | ||||
| deep_in = self.dense_layer_1(deep_in) | deep_in = self.dense_layer_1(deep_in) | ||||