Browse Source

!2167 Add a callback module to avoid the size of the callback.py file too large

Merge pull request !2167 from ougongchang/adjust_callback
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bb622877e8
9 changed files with 29 additions and 12 deletions
  1. +1
    -1
      example/resnet50_imagenet2012_THOR/model/model_thor.py
  2. +1
    -1
      mindspore/ccsrc/utils/callbacks.cc
  3. +1
    -1
      mindspore/ccsrc/utils/callbacks_ge.cc
  4. +20
    -0
      mindspore/train/callback/__init__.py
  5. +1
    -4
      mindspore/train/callback/callback.py
  6. +1
    -1
      mindspore/train/model.py
  7. +1
    -1
      tests/st/networks/models/resnet50/src_thor/model_thor.py
  8. +2
    -2
      tests/ut/python/utils/test_callback.py
  9. +1
    -1
      tests/ut/python/utils/test_serialize.py

+ 1
- 1
example/resnet50_imagenet2012_THOR/model/model_thor.py View File

@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.parallel_utils import ParallelMode

from model.dataset_helper import DatasetHelper


+ 1
- 1
mindspore/ccsrc/utils/callbacks.cc View File

@@ -26,7 +26,7 @@

namespace mindspore {
namespace callbacks {
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback";
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
const char kSummary[] = "Summary";


+ 1
- 1
mindspore/ccsrc/utils/callbacks_ge.cc View File

@@ -25,7 +25,7 @@

namespace mindspore {
namespace callbacks {
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback";
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
const char kSummary[] = "Summary";


+ 20
- 0
mindspore/train/callback/__init__.py View File

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""

from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext

__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"SummaryStep", "CheckpointConfig", "RunContext"]

mindspore/train/callback.py → mindspore/train/callback/callback.py View File

@@ -26,10 +26,7 @@ from mindspore.train._utils import _make_directory
from mindspore import log as logger
from mindspore._checkparam import check_int_non_negative, check_bool
from mindspore.common.tensor import Tensor
from .summary.summary_record import _cache_summary_tensor_data


__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
from mindspore.train.summary.summary_record import _cache_summary_tensor_data


_cur_dir = os.getcwd()

+ 1
- 1
mindspore/train/model.py View File

@@ -19,7 +19,7 @@ from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .callback import _InternalCallbackParam, RunContext, _build_callbacks
from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check


+ 1
- 1
tests/st/networks/models/resnet50/src_thor/model_thor.py View File

@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.parallel_utils import ParallelMode

from .dataset_helper import DatasetHelper


+ 2
- 2
tests/ut/python/utils/test_callback.py View File

@@ -25,8 +25,8 @@ from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \
LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
_checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
_build_callbacks, CheckpointConfig, _set_cur_net




+ 1
- 1
tests/ut/python/utils/test_serialize.py View File

@@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train.callback import _CheckpointManager
from mindspore.train.callback.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
_exec_save_checkpoint, export, _save_graph
from ..ut_filter import non_graph_engine


Loading…
Cancel
Save