Browse Source

!12173 fix the example of MultiFieldEmbeddingLookup operator.

From: @wangshuide2020
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
9b7a3c0ec2
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      mindspore/nn/layer/embedding.py
  2. +2
    -2
      mindspore/ops/operations/nn_ops.py

+ 1
- 1
mindspore/nn/layer/embedding.py View File

@@ -450,7 +450,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
>>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
>>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
>>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM', target='DEVICE')
>>> out = net(input_indices, input_values, field_ids)
>>> print(out.shape)
(2, 2, 2)


+ 2
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -2328,8 +2328,8 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):

Raises:
TypeError: If `is_grad` is not a bool.
TypeError: If dtype of `logits' is neither float16 nor float32.
TypeError: If dtype of `labels' is neither int32 nor int64.
TypeError: If dtype of `logits` is neither float16 nor float32.
TypeError: If dtype of `labels` is neither int32 nor int64.
ValueError: If logits_shape[0] != labels_shape[0].

Supported Platforms:


Loading…
Cancel
Save