diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index 3f36ea5c18..66603d5fbf 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -189,6 +189,7 @@ class EmbeddingLookup(Cell):
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
                  max_norm=None, sparse=True, vocab_cache_size=0):
         super(EmbeddingLookup, self).__init__()
+        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
         self.target = target
         if target not in ('CPU', 'DEVICE'):
             raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
@@ -200,9 +201,9 @@ class EmbeddingLookup(Cell):
         else:
             self.gatherv2 = P.GatherV2()
             self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
-        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
-        self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name)
-        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
+        self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
+        self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
+        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.cache_enable = self.vocab_cache_size > 0
@@ -355,7 +356,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
                  slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                         slice_mode, feature_num_list, max_norm, sparse)
-        self.field_size = validator.check_value_type('field_size', field_size, [int], self.cls_name)
+        self.field_size = validator.check_positive_int(field_size, 'field_size')
         self.operator = operator
 
         self.mul = P.Mul()
@@ -429,7 +430,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             batch_size = self.shape(input_indices)[0]
             num_segments = batch_size * self.field_size
             bias = Range(0, num_segments, self.field_size)()
-            bias = self.reshape(bias, (self.field_size, -1))
+            bias = self.reshape(bias, (batch_size, -1))
             field_ids = self.bias_add(field_ids, bias)
 
         if self.target == "CPU":
diff --git a/tests/ut/python/parallel/test_multi_field_embedding.py b/tests/ut/python/parallel/test_multi_field_embedding.py
index 2ea631ce51..a30a1a7c54 100644
--- a/tests/ut/python/parallel/test_multi_field_embedding.py
+++ b/tests/ut/python/parallel/test_multi_field_embedding.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ from mindspore import Tensor, context
 from mindspore.nn import TrainOneStepCell, Adam
 from tests.ut.python.ops.test_math_ops import VirtualLoss
 
-
 grad_all = C.GradOperation(get_all=True)
 
 
@@ -48,10 +47,11 @@ class NetWithLoss(nn.Cell):
 
 
 class Net(nn.Cell):
-    def __init__(self, shape, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device", operator='SUM'):
+    def __init__(self, shape, field_size=10, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device",
+                 operator='SUM'):
         super().__init__()
         self.embedding = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=64, target=target,
-                                                      field_size=shape[1], slice_mode=slice_mode, operator=operator)
+                                                      field_size=field_size, slice_mode=slice_mode, operator=operator)
         self.reshape = P.Reshape().shard(((8, 1, 1),))
         self.batch_size = shape[0]
 
@@ -77,28 +77,28 @@ def compile_net(net, shape):
 
 
 def test_embeddinglookup_batch_parallel_sum():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, target='DEVICE'))
+    net = NetWithLoss(Net(shape, field_size=10, target='DEVICE'))
     compile_net(net, shape)
 
 
 def test_embeddinglookup_row_parallel_sum():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
+    net = NetWithLoss(Net(shape, field_size=9, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
     compile_net(net, shape)
 
 
 def test_embeddinglookup_column_parallel_sum():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
+    net = NetWithLoss(Net(shape, field_size=10, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
     compile_net(net, shape)
 
 
 def test_embeddinglookup_batch_parallel_mean():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, target='DEVICE', operator='MEAN'))
+    net = NetWithLoss(Net(shape, field_size=1, target='DEVICE', operator='MEAN'))
     compile_net(net, shape)
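
Note on the reshape fix in the @@ -429,7 +430,7 @@ hunk: Range(0, num_segments, self.field_size) produces one offset per sample, i.e. a tensor of length batch_size, so the bias must be laid out one row per sample before it is added to field_ids. The old (self.field_size, -1) layout only reshapes cleanly when field_size happens to divide batch_size, and even then pairs offsets with the wrong rows. Below is a minimal NumPy sketch of the intended offset layout; the names and toy sizes are illustrative only, not MindSpore APIs or values from this patch.

    import numpy as np

    # Toy sizes, assumed for illustration.
    batch_size, field_size, ids_per_sample = 4, 3, 5

    # Per-sample field ids in [0, field_size).
    field_ids = np.random.randint(0, field_size, size=(batch_size, ids_per_sample))

    num_segments = batch_size * field_size
    # One offset per *sample*: [0, field_size, 2*field_size, ...], length batch_size.
    bias = np.arange(0, num_segments, field_size)

    # Buggy layout: (field_size, -1) pairs offsets with the wrong samples and
    # fails outright unless field_size divides batch_size.
    # bias = bias.reshape(field_size, -1)

    # Fixed layout: one row per sample, broadcast across that sample's ids.
    bias = bias.reshape(batch_size, -1)        # shape (batch_size, 1)

    global_ids = field_ids + bias
    assert global_ids.max() < num_segments     # every id lands in its own sample's segment range

With shape (batch_size, 1) the offsets broadcast across each row, so sample i's field ids map into segments [i * field_size, (i + 1) * field_size), which is exactly the num_segments = batch_size * field_size layout the subsequent segment reduction expects.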