""" ######################## train lenet example ######################## train lenet and get network model files(.ckpt) """ #!/usr/bin/python #coding=utf-8 import os import argparse import moxing as mox from config import mnist_cfg as cfg from dataset import create_dataset from dataset_distributed import create_dataset_parallel from lenet import LeNet5 import json import mindspore.nn as nn from mindspore import context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore import load_checkpoint, load_param_into_net from mindspore.context import ParallelMode from mindspore.communication.management import init, get_rank import time ### Copy multiple datasets from obs to training image ### def MultiObsToEnv(multi_data_url, data_dir): #--multi_data_url is json data, need to do json parsing for multi_data_url multi_data_json = json.loads(multi_data_url) for i in range(len(multi_data_json)): path = data_dir + "/" + multi_data_json[i]["dataset_name"] file_path = data_dir + "/" + os.path.splitext(multi_data_json[i]["dataset_name"])[0] if not os.path.exists(file_path): os.makedirs(file_path) try: mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path) print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"],path)) #unzip dataset os.system("unzip -d %s %s" % (file_path, path)) except Exception as e: print('moxing download {} to {} failed: '.format( multi_data_json[i]["dataset_url"], path) + str(e)) #Set a cache file to determine whether the data has been copied to obs. #If this file exists during multi-card training, there is no need to copy the dataset multiple times. 
f = open("/cache/download_input.txt", 'w') f.close() try: if os.path.exists("/cache/download_input.txt"): print("download_input succeed") except Exception as e: print("download_input failed") return def DownloadFromQizhi(multi_data_url, data_dir): device_num = int(os.getenv('RANK_SIZE')) if device_num == 1: MultiObsToEnv(multi_data_url,data_dir) context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target) if device_num > 1: # set device_id and init for multi-card training context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID'))) context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True) init() #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data local_rank=int(os.getenv('RANK_ID')) if local_rank%8==0: MultiObsToEnv(multi_data_url,data_dir) #If the cache file does not exist, it means that the copy data has not been completed, #and Wait for 0th card to finish copying data while not os.path.exists("/cache/download_input.txt"): time.sleep(1) return parser = argparse.ArgumentParser(description='MindSpore Lenet Example') ### --multi_data_url,--ckpt_url,--device_target,These 4 parameters must be defined first in a multi-dataset, ### otherwise an error will be reported. ### There is no need to add these parameters to the running parameters of the Qizhi platform, ### because they are predefined in the background, you only need to define them in your code. 
parser.add_argument('--multi_data_url',
                    help='dataset path in obs')

parser.add_argument('--ckpt_url',
                    help='pre_train_model path in obs')

parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')

parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')

if __name__ == "__main__":
    import subprocess

    # `args` stays a module-level name on purpose: DownloadFromQizhi reads
    # args.device_target as a global. Unrecognised options are tolerated
    # because parse_known_args is used instead of parse_args.
    args, unknown = parser.parse_known_args()
    data_dir = '/cache/dataset'
    train_dir = '/cache/output'
    # exist_ok avoids the check-then-create race when several ranks start
    # this script at the same time.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)
    ###Initialize and copy data to training image
    DownloadFromQizhi(args.multi_data_url, data_dir)
    print("--------start ls:")
    # List what was downloaded; uses data_dir rather than repeating the path.
    subprocess.run(["ls", "-al", data_dir], check=False)
    print("--------end ls-----------")