| """ | |||
| ######################## train lenet example ######################## | |||
| train lenet and get network model files(.ckpt) | |||
| """ | |||
| #!/usr/bin/python | |||
| #coding=utf-8 | |||
import argparse
import json
import os
import time

import moxing as mox
import mindspore.nn as nn
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
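
# moxing (mox) is the OBS file I/O SDK preinstalled in the platform's training
# image; mox.file.copy_parallel copies objects between OBS and the local
# filesystem in parallel.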

### Copy multiple datasets from OBS to the training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    # --multi_data_url is a JSON string, so parse it before use.
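    # Illustrative payload (hypothetical values; the real JSON is injected by
    # the Qizhi platform through the predefined --multi_data_url argument):
    # [{"dataset_name": "MNISTData.zip",
    #   "dataset_url": "s3://your-bucket/datasets/MNISTData.zip"}]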
    multi_data_json = json.loads(multi_data_url)
    for dataset in multi_data_json:
        zip_path = os.path.join(data_dir, dataset["dataset_name"])
        extract_path = os.path.join(data_dir, os.path.splitext(dataset["dataset_name"])[0])
        if not os.path.exists(extract_path):
            os.makedirs(extract_path)
        try:
            mox.file.copy_parallel(dataset["dataset_url"], zip_path)
            print("Successfully downloaded {} to {}".format(dataset["dataset_url"], zip_path))
            # Unzip the dataset archive into its own directory
            os.system("unzip -d %s %s" % (extract_path, zip_path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                dataset["dataset_url"], zip_path) + str(e))
    # Create a flag file marking that the data has been copied from OBS.
    # If this file exists during multi-card training, the other cards do not
    # need to copy the dataset again.
    open("/cache/download_input.txt", 'w').close()
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return
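
# RANK_SIZE, RANK_ID and ASCEND_DEVICE_ID are environment variables set by the
# platform's distributed launcher: RANK_SIZE is the total device count and
# RANK_ID is this process's global rank. The RANK_ID % 8 check below assumes
# 8 devices per node, so the 0th card of each node performs the copy.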
def DownloadFromQizhi(multi_data_url, data_dir):
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # Set device_id and call init() for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, parameter_broadcast=True)
        init()
        # Copying OBS data does not need to be executed by every card;
        # only the 0th card of each node copies the data.
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        # If the flag file does not exist, the copy has not finished yet;
        # wait for the 0th card to finish copying the data.
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
### --multi_data_url, --ckpt_url and --device_target must be defined first when
### training with multiple datasets, otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the
### Qizhi platform; they are predefined in the background, you only need to
### define them in your code.
parser.add_argument('--multi_data_url',
                    help='dataset path in obs')
parser.add_argument('--ckpt_url',
                    help='pre_train_model path in obs')
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will be run (default: Ascend); '
                         'to use the CPU on the Qizhi platform, set device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
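
# Illustrative local launch (hypothetical script name; on the Qizhi platform
# --multi_data_url and --ckpt_url are injected by the platform itself):
#   python train.py --device_target=Ascend --epoch_size=5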

if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
    data_dir = '/cache/dataset'
    train_dir = '/cache/output'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    ### Initialize and copy data to the training image
    DownloadFromQizhi(args.multi_data_url, data_dir)
    print("--------start ls:")
    os.system("cd /cache/dataset; ls -al")
    print("--------end ls-----------")