@@ -189,6 +189,7 @@ class EmbeddingLookup(Cell):
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
                  max_norm=None, sparse=True, vocab_cache_size=0):
         super(EmbeddingLookup, self).__init__()
+        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
         self.target = target
         if target not in ('CPU', 'DEVICE'):
             raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
@@ -200,9 +201,9 @@ class EmbeddingLookup(Cell):
         else:
             self.gatherv2 = P.GatherV2()
             self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
-        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
-        self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name)
-        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
+        self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
+        self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
+        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.cache_enable = self.vocab_cache_size > 0
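# Illustrative sketch only: the hunk above tightens validation from a pure type check to a
# value check. The helpers below are hypothetical stand-ins, not the real mindspore
# validator implementation, and only show the intended behaviour.
def _check_positive_int_sketch(value, name):
    # a positive-int check also rejects zero and negative sizes, e.g. vocab_size=0
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise ValueError(f"'{name}' must be a positive int, but got {value!r}")
    return value

def _check_non_negative_int_sketch(value, name):
    # vocab_cache_size may legitimately be 0 (cache disabled), so only negatives are rejected
    if not isinstance(value, int) or isinstance(value, bool) or value < 0:
        raise ValueError(f"'{name}' must be a non-negative int, but got {value!r}")
    return value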
@@ -355,7 +356,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
                  slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                         slice_mode, feature_num_list, max_norm, sparse)
-        self.field_size = validator.check_value_type('field_size', field_size, [int], self.cls_name)
+        self.field_size = validator.check_positive_int(field_size, 'field_size')
         self.operator = operator

         self.mul = P.Mul()
@@ -429,7 +430,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
         batch_size = self.shape(input_indices)[0]
         num_segments = batch_size * self.field_size
         bias = Range(0, num_segments, self.field_size)()
-        bias = self.reshape(bias, (self.field_size, -1))
+        bias = self.reshape(bias, (batch_size, -1))
         field_ids = self.bias_add(field_ids, bias)

         if self.target == "CPU":
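# Illustrative NumPy sketch (assumed sizes) of why the reshape target must be (batch_size, -1):
# Range(0, batch_size * field_size, field_size) yields exactly batch_size offsets, one per
# sample, so only a (batch_size, 1) column broadcasts correctly against field_ids of shape
# (batch_size, seq_len).
import numpy as np

batch_size, field_size, seq_len = 4, 3, 5
field_ids = np.random.randint(0, field_size, (batch_size, seq_len))
bias = np.arange(0, batch_size * field_size, field_size)      # shape (batch_size,)
global_ids = field_ids + bias.reshape(batch_size, -1)         # per-sample segment ids
assert global_ids.max() < batch_size * field_size
# bias.reshape(field_size, -1) would raise or mis-broadcast whenever field_size != batch_size,
# which is the bug the one-line fix above addresses.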
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ from mindspore import Tensor, context
 from mindspore.nn import TrainOneStepCell, Adam
 from tests.ut.python.ops.test_math_ops import VirtualLoss
-
 grad_all = C.GradOperation(get_all=True)
@@ -48,10 +47,11 @@ class NetWithLoss(nn.Cell):
 class Net(nn.Cell):
-    def __init__(self, shape, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device", operator='SUM'):
+    def __init__(self, shape, field_size=10, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, target="Device",
+                 operator='SUM'):
         super().__init__()
         self.embedding = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=64, target=target,
-                                                      field_size=shape[1], slice_mode=slice_mode, operator=operator)
+                                                      field_size=field_size, slice_mode=slice_mode, operator=operator)
         self.reshape = P.Reshape().shard(((8, 1, 1),))
         self.batch_size = shape[0]
@@ -77,28 +77,28 @@ def compile_net(net, shape):
 def test_embeddinglookup_batch_parallel_sum():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, target='DEVICE'))
+    net = NetWithLoss(Net(shape, field_size=10, target='DEVICE'))
     compile_net(net, shape)


 def test_embeddinglookup_row_parallel_sum():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
+    net = NetWithLoss(Net(shape, field_size=9, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE'))
     compile_net(net, shape)


 def test_embeddinglookup_column_parallel_sum():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
+    net = NetWithLoss(Net(shape, field_size=10, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE'))
     compile_net(net, shape)


 def test_embeddinglookup_batch_parallel_mean():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     shape = [64, 64]
-    net = NetWithLoss(Net(shape, target='DEVICE', operator='MEAN'))
+    net = NetWithLoss(Net(shape, field_size=1, target='DEVICE', operator='MEAN'))
     compile_net(net, shape)
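# Usage sketch mirroring the tests above (the helper name is assumed, not part of the patch):
# with field_size now decoupled from shape[1], one helper can drive every slice mode,
# including field counts such as 9 that are unrelated to the 64-column input shape.
def build_field_net(field_size, slice_mode=nn.EmbeddingLookup.BATCH_SLICE, operator='SUM'):
    shape = [64, 64]
    net = NetWithLoss(Net(shape, field_size=field_size, slice_mode=slice_mode,
                          target='DEVICE', operator=operator))
    return net, shape

# e.g.:
#   net, shape = build_field_net(field_size=9, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
#   compile_net(net, shape)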