From: @bai-yangfan Reviewed-by: @chenfei52,@zh_qh Signed-off-by: @zh_qhtags/v1.1.0
| @@ -20,5 +20,11 @@ Helper functions in train piplines. | |||||
| from .model import Model | from .model import Model | ||||
| from .dataset_helper import DatasetHelper, connect_network_with_dataset | from .dataset_helper import DatasetHelper, connect_network_with_dataset | ||||
| from . import amp | from . import amp | ||||
| from .amp import build_train_network | |||||
| from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager | |||||
| from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\ | |||||
| build_searched_strategy, merge_sliced_parameter | |||||
| __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset"] | |||||
| __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager", | |||||
| "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", | |||||
| "load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"] | |||||
| @@ -26,8 +26,6 @@ from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager | |||||
| from ..context import ParallelMode | from ..context import ParallelMode | ||||
| from .. import context | from .. import context | ||||
| __all__ = ["build_train_network"] | |||||
| class OutputTo16(nn.Cell): | class OutputTo16(nn.Cell): | ||||
| "Wrap cell for amp. Cast network output back to float16" | "Wrap cell for amp. Cast network output back to float16" | ||||
| @@ -17,8 +17,6 @@ | |||||
| from .._checkparam import Validator as validator | from .._checkparam import Validator as validator | ||||
| from .. import nn | from .. import nn | ||||
| __all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"] | |||||
| class LossScaleManager: | class LossScaleManager: | ||||
| """Loss scale manager abstract class.""" | """Loss scale manager abstract class.""" | ||||
| @@ -33,8 +33,6 @@ from mindspore._checkparam import check_input_data, Validator | |||||
| from mindspore.compression.export import quant_export | from mindspore.compression.export import quant_export | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", | |||||
| "build_searched_strategy", "merge_sliced_parameter"] | |||||
| tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | ||||
| "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, | "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, | ||||
| @@ -20,9 +20,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg | from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg | ||||
| from src.alexnet import AlexNet | from src.alexnet import AlexNet | ||||
| @@ -16,9 +16,8 @@ | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor, context | |||||
| from mindspore import Tensor, context, load_checkpoint, export | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.train.serialization import load_checkpoint, export | |||||
| from src.config import Config_CNNCTC | from src.config import Config_CNNCTC | ||||
| from src.cnn_ctc import CNNCTC_Model | from src.cnn_ctc import CNNCTC_Model | ||||
| @@ -16,10 +16,7 @@ | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.train.serialization import export | |||||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||||
| from src.nets import net_factory | from src.nets import net_factory | ||||
| @@ -17,8 +17,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 | ||||
| from src.config import config | from src.config import config | ||||
| @@ -20,8 +20,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import cifar_cfg, imagenet_cfg | from src.config import cifar_cfg, imagenet_cfg | ||||
| from src.googlenet import GoogleNet | from src.googlenet import GoogleNet | ||||
| @@ -19,8 +19,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import config_gpu as cfg | from src.config import config_gpu as cfg | ||||
| from src.inception_v3 import InceptionV3 | from src.inception_v3 import InceptionV3 | ||||
| @@ -20,9 +20,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore | import mindspore | ||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| from src.lenet import LeNet5 | from src.lenet import LeNet5 | ||||
| @@ -20,10 +20,8 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore | import mindspore | ||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | from mindspore.compression.quant import QuantizationAwareTraining | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | from src.lenet_fusion import LeNet5 as LeNet5Fusion | ||||
| @@ -16,8 +16,7 @@ | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor, context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||||
| from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | ||||
| from src.config import config | from src.config import config | ||||
| @@ -16,8 +16,7 @@ | |||||
| mobilenetv2 export mindir. | mobilenetv2 export mindir. | ||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import export | |||||
| from mindspore import Tensor, export | |||||
| from src.config import set_config | from src.config import set_config | ||||
| from src.args import export_parse_args | from src.args import export_parse_args | ||||
| from src.models import define_net, load_ckpt | from src.models import define_net, load_ckpt | ||||
| @@ -18,9 +18,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore | import mindspore | ||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||||
| from mindspore.compression.quant import QuantizationAwareTraining | from mindspore.compression.quant import QuantizationAwareTraining | ||||
| from src.mobilenetV2 import mobilenetV2 | from src.mobilenetV2 import mobilenetV2 | ||||
| @@ -17,8 +17,7 @@ mobilenetv3 export mindir. | |||||
| """ | """ | ||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context, Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import config_gpu | from src.config import config_gpu | ||||
| from src.mobilenetV3 import mobilenet_v3_large | from src.mobilenetV3 import mobilenet_v3_large | ||||
| @@ -19,8 +19,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import nasnet_a_mobile_config_gpu as cfg | from src.config import nasnet_a_mobile_config_gpu as cfg | ||||
| from src.nasnet_a_mobile import NASNetAMobile | from src.nasnet_a_mobile import NASNetAMobile | ||||
| @@ -19,8 +19,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import config | from src.config import config | ||||
| from src.ETSNET.etsnet import ETSNet | from src.ETSNET.etsnet import ETSNet | ||||
| @@ -19,8 +19,7 @@ python export.py | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| parser = argparse.ArgumentParser(description='resnet export') | parser = argparse.ArgumentParser(description='resnet export') | ||||
| @@ -16,9 +16,7 @@ | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context | |||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.resnet_thor import resnet50 as resnet | from src.resnet_thor import resnet50 as resnet | ||||
| from src.config import config | from src.config import config | ||||
| @@ -17,8 +17,7 @@ resnext export mindir. | |||||
| """ | """ | ||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context, Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.config import config | from src.config import config | ||||
| from src.image_classification import get_network | from src.image_classification import get_network | ||||
| @@ -17,8 +17,7 @@ ssd export mindir. | |||||
| """ | """ | ||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context, Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | |||||
| from src.ssd import SSD300, ssd_mobilenet_v2 | from src.ssd import SSD300, ssd_mobilenet_v2 | ||||
| from src.config import config | from src.config import config | ||||
| @@ -16,8 +16,7 @@ | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import export, load_checkpoint, load_param_into_net | |||||
| from mindspore import Tensor, export, load_checkpoint, load_param_into_net | |||||
| from src.unet.unet_model import UNet | from src.unet.unet_model import UNet | ||||
| @@ -16,8 +16,7 @@ | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import Tensor, context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||||
| from src.warpctc import StackedRNN | from src.warpctc import StackedRNN | ||||
| from src.config import config | from src.config import config | ||||
| @@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from mindspore import load_checkpoint, load_param_into_net, export | |||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.compression.quant import QuantizationAwareTraining | from mindspore.compression.quant import QuantizationAwareTraining | ||||
| from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net | from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net | ||||