From 00e05f7c34b96e16ab877be452d83bcf85361349 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 2 Jun 2020 14:32:20 +0800 Subject: [PATCH] fixed doc for merge fixed Embedding --- mindspore/nn/layer/embedding.py | 8 +++++--- mindspore/ops/operations/control_ops.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index 5df38b6845..e27cd765af 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -44,10 +44,11 @@ class Embedding(Cell): dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32. Inputs: - - **input** (Tensor) - Tensor of shape :math:`(\text{vocab_size})`. - + - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of + the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero + if larger than vocab_size. Outputs: - Tensor of shape :math:`(\text{vocab_size}, \text{embedding_size})`. + Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. Examples: >>> net = nn.Embedding(20000, 768, True) @@ -61,6 +62,7 @@ class Embedding(Cell): def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): super(Embedding, self).__init__() validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) + validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name) self.vocab_size = vocab_size self.embedding_size = embedding_size self.use_one_hot = use_one_hot diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index 2c804c483f..e7ac4572ce 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -144,7 +144,7 @@ class Merge(PrimitiveWithInfer): One and only one of the inputs should be selected as the output Inputs: - - **inputs** (Tuple) - The data to be merged. All tuple elements should have same data type. + - **inputs** (Union(Tuple, List)) - The data to be merged. All tuple elements should have same data type. Outputs: tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.