Browse Source

!12870 add mixed precision for cache

From: @fangzehua
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
659b912f6d
2 changed files with 10 additions and 1 deletion
  1. +3
    -0
      mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
  2. +7
    -1
      mindspore/train/callback/_checkpoint.py

+ 3
- 0
mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc View File

@@ -446,6 +446,9 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
for (size_t i = 0; i < cnodes_size; ++i) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
auto load_node = cnodes[i]->input(1);
if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
load_node = load_node->cast<CNodePtr>()->input(1);
}
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
if (param_set.find(param_node) != param_set.end()) {


+ 7
- 1
mindspore/train/callback/_checkpoint.py View File

@@ -261,6 +261,7 @@ class ModelCheckpoint(Callback):
self._manager = CheckpointManager()
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
self._graph_saved = False
self._need_flush_from_cache = True

def step_end(self, run_context):
"""
@@ -326,7 +327,8 @@ class ModelCheckpoint(Callback):
return

# if param is cache enable, flush data from cache to host before save_ckpt
self._flush_from_cache(cb_params)
if self._need_flush_from_cache:
self._flush_from_cache(cb_params)

save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
@@ -365,10 +367,14 @@ class ModelCheckpoint(Callback):

def _flush_from_cache(self, cb_params):
    """
    Flush cache-enabled parameters of the training network back to the host.

    Each parameter returned by `cb_params.train_network.get_parameters()` whose
    `cache_enable` attribute is truthy is wrapped in a `Tensor` and has
    `flush_from_cache()` called on it. If no parameter is cache-enabled,
    `self._need_flush_from_cache` is set to False.

    Args:
        cb_params: Callback parameters object; only `train_network` is read here.
    """
    flushed_any = False
    for net_param in cb_params.train_network.get_parameters():
        if not net_param.cache_enable:
            continue
        flushed_any = True
        Tensor(net_param).flush_from_cache()
    # Nothing in this network uses the cache, so record that on the instance.
    if not flushed_any:
        self._need_flush_from_cache = False

@property
def latest_ckpt_file_name(self):


Loading…
Cancel
Save