|
|
|
@@ -181,8 +181,8 @@ class EmbeddingLookup(Cell): |
|
|
|
or None. Default: None |
|
|
|
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. |
|
|
|
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in |
|
|
|
'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size. |
|
|
|
In addition, it should be noted that it will cost the 'DEVICE' |
|
|
|
parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding |
|
|
|
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE' |
|
|
|
memory, so suggests setting a reasonable value to avoid insufficient memory. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -303,35 +303,13 @@ class EmbeddingLookup(Cell): |
|
|
|
raise ValueError("For '{}', the 'slice_mode' must be in {}, " |
|
|
|
"but got \"{}\".".format(self.cls_name, support_mode, slice_mode)) |
|
|
|
if self.cache_enable and not enable_ps: |
|
|
|
if parallel_mode != ParallelMode.STAND_ALONE: |
|
|
|
raise ValueError(f"For '{self.cls_name}', parallel mode haven't supported cache enable yet.") |
|
|
|
self._set_cache_enable() |
|
|
|
raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.") |
|
|
|
self.embedding_table.unique = self.forward_unique |
|
|
|
self.max_norm = max_norm |
|
|
|
if self.max_norm is not None: |
|
|
|
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) |
|
|
|
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) |
|
|
|
|
|
|
|
def _set_cache_enable(self): |
|
|
|
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'.""" |
|
|
|
if self.target != 'DEVICE': |
|
|
|
raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only " |
|
|
|
f"when 'target' is 'DEVICE', but got 'target': {self.target}") |
|
|
|
if not self.sparse: |
|
|
|
raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only " |
|
|
|
f"when 'sparse' is true, but got 'sparse': {self.sparse}.") |
|
|
|
if context.get_context("device_target") != 'Ascend': |
|
|
|
raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only " |
|
|
|
f"when device target is 'Ascend', but got {context.get_context('device_target')}.") |
|
|
|
|
|
|
|
logger.info("EmbeddingLookup cache enable takes effect.") |
|
|
|
self.forward_unique = True |
|
|
|
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU') |
|
|
|
self.unique.add_prim_attr('cache_enable', True) |
|
|
|
self.embedding_table.cache_enable = self.cache_enable |
|
|
|
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size) |
|
|
|
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU') |
|
|
|
|
|
|
|
def _process_vocab_cache(self, slice_mode): |
|
|
|
"""PS embeddingLookup cache check and process.""" |
|
|
|
self.cache_enable = False |
|
|
|
|