Browse Source

!8420 mode_train_file

From: @bai-yangfan
Reviewed-by: @chenfei52,@zh_qh
Signed-off-by: @zh_qh
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
623df51e06
25 changed files with 28 additions and 55 deletions
  1. +7
    -1
      mindspore/train/__init__.py
  2. +0
    -2
      mindspore/train/amp.py
  3. +0
    -2
      mindspore/train/loss_scale_manager.py
  4. +0
    -2
      mindspore/train/serialization.py
  5. +1
    -3
      model_zoo/official/cv/alexnet/export.py
  6. +1
    -2
      model_zoo/official/cv/cnnctc/export.py
  7. +1
    -4
      model_zoo/official/cv/deeplabv3/export.py
  8. +1
    -2
      model_zoo/official/cv/faster_rcnn/export.py
  9. +1
    -2
      model_zoo/official/cv/googlenet/export.py
  10. +1
    -2
      model_zoo/official/cv/inceptionv3/export.py
  11. +1
    -3
      model_zoo/official/cv/lenet/export.py
  12. +1
    -3
      model_zoo/official/cv/lenet_quant/export.py
  13. +1
    -2
      model_zoo/official/cv/maskrcnn/export.py
  14. +1
    -2
      model_zoo/official/cv/mobilenetv2/export.py
  15. +1
    -3
      model_zoo/official/cv/mobilenetv2_quant/export.py
  16. +1
    -2
      model_zoo/official/cv/mobilenetv3/export.py
  17. +1
    -2
      model_zoo/official/cv/nasnet/export.py
  18. +1
    -2
      model_zoo/official/cv/psenet/export.py
  19. +1
    -2
      model_zoo/official/cv/resnet/export.py
  20. +1
    -3
      model_zoo/official/cv/resnet_thor/export.py
  21. +1
    -2
      model_zoo/official/cv/resnext50/export.py
  22. +1
    -2
      model_zoo/official/cv/ssd/export.py
  23. +1
    -2
      model_zoo/official/cv/unet/export.py
  24. +1
    -2
      model_zoo/official/cv/warpctc/export.py
  25. +1
    -1
      tests/st/quantization/lenet_quant/test_lenet_quant.py

+ 7
- 1
mindspore/train/__init__.py View File

@@ -20,5 +20,11 @@ Helper functions in train piplines.
from .model import Model from .model import Model
from .dataset_helper import DatasetHelper, connect_network_with_dataset from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
build_searched_strategy, merge_sliced_parameter


__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset"]
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"]

+ 0
- 2
mindspore/train/amp.py View File

@@ -26,8 +26,6 @@ from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode from ..context import ParallelMode
from .. import context from .. import context


__all__ = ["build_train_network"]



class OutputTo16(nn.Cell): class OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16" "Wrap cell for amp. Cast network output back to float16"


+ 0
- 2
mindspore/train/loss_scale_manager.py View File

@@ -17,8 +17,6 @@
from .._checkparam import Validator as validator from .._checkparam import Validator as validator
from .. import nn from .. import nn


__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]



class LossScaleManager: class LossScaleManager:
"""Loss scale manager abstract class.""" """Loss scale manager abstract class."""


+ 0
- 2
mindspore/train/serialization.py View File

@@ -33,8 +33,6 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export from mindspore.compression.export import quant_export
import mindspore.context as context import mindspore.context as context


__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]


tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,


+ 1
- 3
model_zoo/official/cv/alexnet/export.py View File

@@ -20,9 +20,7 @@ import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export


from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.alexnet import AlexNet from src.alexnet import AlexNet


+ 1
- 2
model_zoo/official/cv/cnnctc/export.py View File

@@ -16,9 +16,8 @@
import argparse import argparse
import numpy as np import numpy as np


from mindspore import Tensor, context
from mindspore import Tensor, context, load_checkpoint, export
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, export


from src.config import Config_CNNCTC from src.config import Config_CNNCTC
from src.cnn_ctc import CNNCTC_Model from src.cnn_ctc import CNNCTC_Model


+ 1
- 4
model_zoo/official/cv/deeplabv3/export.py View File

@@ -16,10 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np


from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export


from src.nets import net_factory from src.nets import net_factory




+ 1
- 2
model_zoo/official/cv/faster_rcnn/export.py View File

@@ -17,8 +17,7 @@ import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export


from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config from src.config import config


+ 1
- 2
model_zoo/official/cv/googlenet/export.py View File

@@ -20,8 +20,7 @@ import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export


from src.config import cifar_cfg, imagenet_cfg from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet from src.googlenet import GoogleNet


+ 1
- 2
model_zoo/official/cv/inceptionv3/export.py View File

@@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export


from src.config import config_gpu as cfg from src.config import config_gpu as cfg
from src.inception_v3 import InceptionV3 from src.inception_v3 import InceptionV3


+ 1
- 3
model_zoo/official/cv/lenet/export.py View File

@@ -20,9 +20,7 @@ import argparse
import numpy as np import numpy as np


import mindspore import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export


from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet import LeNet5 from src.lenet import LeNet5


+ 1
- 3
model_zoo/official/cv/lenet_quant/export.py View File

@@ -20,10 +20,8 @@ import argparse
import numpy as np import numpy as np


import mindspore import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export


from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion from src.lenet_fusion import LeNet5 as LeNet5Fusion


+ 1
- 2
model_zoo/official/cv/maskrcnn/export.py View File

@@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np


from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export


from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config from src.config import config


+ 1
- 2
model_zoo/official/cv/mobilenetv2/export.py View File

@@ -16,8 +16,7 @@
mobilenetv2 export mindir. mobilenetv2 export mindir.
""" """
import numpy as np import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export
from mindspore import Tensor, export
from src.config import set_config from src.config import set_config
from src.args import export_parse_args from src.args import export_parse_args
from src.models import define_net, load_ckpt from src.models import define_net, load_ckpt


+ 1
- 3
model_zoo/official/cv/mobilenetv2_quant/export.py View File

@@ -18,9 +18,7 @@ import argparse
import numpy as np import numpy as np


import mindspore import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining


from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2


+ 1
- 2
model_zoo/official/cv/mobilenetv3/export.py View File

@@ -17,8 +17,7 @@ mobilenetv3 export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config_gpu from src.config import config_gpu
from src.mobilenetV3 import mobilenet_v3_large from src.mobilenetV3 import mobilenet_v3_large




+ 1
- 2
model_zoo/official/cv/nasnet/export.py View File

@@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export


from src.config import nasnet_a_mobile_config_gpu as cfg from src.config import nasnet_a_mobile_config_gpu as cfg
from src.nasnet_a_mobile import NASNetAMobile from src.nasnet_a_mobile import NASNetAMobile


+ 1
- 2
model_zoo/official/cv/psenet/export.py View File

@@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export


from src.config import config from src.config import config
from src.ETSNET.etsnet import ETSNet from src.ETSNET.etsnet import ETSNet


+ 1
- 2
model_zoo/official/cv/resnet/export.py View File

@@ -19,8 +19,7 @@ python export.py
import argparse import argparse
import numpy as np import numpy as np


from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export


if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='resnet export') parser = argparse.ArgumentParser(description='resnet export')


+ 1
- 3
model_zoo/official/cv/resnet_thor/export.py View File

@@ -16,9 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np


from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.resnet_thor import resnet50 as resnet from src.resnet_thor import resnet50 as resnet
from src.config import config from src.config import config




+ 1
- 2
model_zoo/official/cv/resnext50/export.py View File

@@ -17,8 +17,7 @@ resnext export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config from src.config import config
from src.image_classification import get_network from src.image_classification import get_network




+ 1
- 2
model_zoo/official/cv/ssd/export.py View File

@@ -17,8 +17,7 @@ ssd export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, ssd_mobilenet_v2 from src.ssd import SSD300, ssd_mobilenet_v2
from src.config import config from src.config import config




+ 1
- 2
model_zoo/official/cv/unet/export.py View File

@@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np


from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from mindspore import Tensor, export, load_checkpoint, load_param_into_net


from src.unet.unet_model import UNet from src.unet.unet_model import UNet




+ 1
- 2
model_zoo/official/cv/warpctc/export.py View File

@@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np


from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export


from src.warpctc import StackedRNN from src.warpctc import StackedRNN
from src.config import config from src.config import config


+ 1
- 1
tests/st/quantization/lenet_quant/test_lenet_quant.py View File

@@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import load_checkpoint, load_param_into_net, export
from mindspore.train import Model from mindspore.train import Model
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net


Loading…
Cancel
Save