| """ | |||||
| ######################## train lenet example ######################## | |||||
| train lenet and get network model files(.ckpt) | |||||
| """ | |||||
| #!/usr/bin/python | |||||
| #coding=utf-8 | |||||
import os
import argparse
import json
import time
import moxing as mox
import mindspore.nn as nn
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5

### Copy multiple datasets from OBS to the training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    # --multi_data_url is a JSON string, so parse it first
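    # The parsed value is assumed to be a list of entries shaped like the
    # hypothetical example below, based on how the fields are accessed here:
    # [{"dataset_name": "MNISTData.zip",
    #   "dataset_url": "s3://bucket/path/MNISTData.zip"}, ...]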
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        path = data_dir + "/" + multi_data_json[i]["dataset_name"]
        file_path = data_dir + "/" + os.path.splitext(multi_data_json[i]["dataset_name"])[0]
        if not os.path.exists(file_path):
            os.makedirs(file_path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully downloaded {} to {}".format(multi_data_json[i]["dataset_url"], path))
            # unzip the dataset archive into its own directory
            os.system("unzip -d %s %s" % (file_path, path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                multi_data_json[i]["dataset_url"], path) + str(e))
    # Create a cache file to mark that the data has been copied from OBS.
    # If this file already exists during multi-card training, the dataset
    # does not need to be copied again.
    with open("/cache/download_input.txt", 'w'):
        pass
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return

def DownloadFromQizhi(multi_data_url, data_dir):
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init() for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()
        # The OBS copy only needs to run once; let only the 0th card of each
        # node copy the data
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        # If the cache file does not exist, the copy has not finished yet;
        # wait for the 0th card to finish copying the data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
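
# Checkpoints written to train_dir typically need to be copied back to OBS to
# be retrieved after training. Below is a minimal sketch of that upload step;
# obs_train_url is a hypothetical output path (this script does not define a
# corresponding argument), so adapt it to the output URL your platform provides.
def UploadToQizhi(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully uploaded {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return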

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
### --multi_data_url, --ckpt_url and --device_target must be defined first
### in a multi-dataset job, otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the
### Qizhi platform, because they are predefined in the background; you only
### need to define them in your code.
parser.add_argument('--multi_data_url',
                    help='dataset path in obs')
parser.add_argument('--ckpt_url',
                    help='pre_train_model path in obs')
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will run (default: Ascend); '
                         'to use the CPU on the Qizhi platform, set --device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
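
# For local debugging outside the Qizhi platform (which normally injects these
# arguments in the background), a hypothetical invocation might look like:
#   python train.py \
#       --multi_data_url '[{"dataset_name": "MNISTData.zip", "dataset_url": "s3://bucket/MNISTData.zip"}]' \
#       --device_target Ascend --epoch_size 5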

if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
    data_dir = '/cache/dataset'
    train_dir = '/cache/output'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    ### Initialize and copy data to the training image
    DownloadFromQizhi(args.multi_data_url, data_dir)
    print("--------start ls:")
    os.system("cd /cache/dataset; ls -al")
    print("--------end ls-----------")