| @@ -0,0 +1,172 @@ | |||||
| # Contents | |||||
| - [CNN-Direction-Model Description](#cnn-direction-model-description) | |||||
| - [Model Architecture](#model-architecture) | |||||
| - [Dataset](#dataset) | |||||
| - [Environment Requirements](#environment-requirements) | |||||
| - [Quick Start](#quick-start) | |||||
| - [Script Description](#script-description) | |||||
| - [Script and Sample Code](#script-and-sample-code) | |||||
| - [Script Parameters](#script-parameters) | |||||
| - [Training Process](#training-process) | |||||
| - [Training](#training) | |||||
| - [Evaluation Process](#evaluation-process) | |||||
| - [Evaluation](#evaluation) | |||||
| - [Model Description](#model-description) | |||||
| - [Performance](#performance) | |||||
| - [Evaluation Performance](#evaluation-performance) | |||||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||||
| # [CNN-Direction-Model Description](#contents) | |||||
| CNN Direction Model is a model designed to perform binary classification of text images on whether the text in the image is going from left-to-right or right-to-left. | |||||
| # [Model Architecture](#contents) | |||||
| CNN Direction Model's composition consists of 1 convolutional layer and 4 residual blocks for feature extraction. The feature extraction stage is then followed by 3 dense layers to perform the classification. | |||||
| # [Dataset](#contents) | |||||
| Dataset used: [FSNS (French Street Name Signs)](https://arxiv.org/abs/1702.03970) | |||||
| - Dataset size:~200GB,~1M 150*600 colored images with a label indicating the text within the image. | |||||
| - Train:200GB,1M, images | |||||
| - Test:4GB,24,404 images | |||||
| - Data format:binary files | |||||
| - Note:Data will be processed in dataset.py | |||||
| - Download the dataset, the recommened directory structure to have is as follows: | |||||
| Annotations for training and testing should be in test_annot and train_annot. | |||||
| Training and Testing images should be in train and test. | |||||
| ```shell | |||||
| ├─test | |||||
| │ | |||||
| └─test_annot | |||||
| │ | |||||
| └─train | |||||
| │ | |||||
| └─train_annot | |||||
| ``` | |||||
| - After downloading the data and converting it to it's raw format (.txt for annotations and .jpg, .jpeg, or .png for the images), add the image and annotations paths to the src/config.py file then cd to src and run: | |||||
| ```python | |||||
| python create_mindrecord.py | |||||
| ``` | |||||
| This will create two folders: train and test in the target directory you specify in config.py. | |||||
| # [Environment Requirements](#contents) | |||||
| - Hardware(Ascend/GPU) | |||||
| - Prepare hardware environment with Ascend or GPU processor. | |||||
| - Framework | |||||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||||
| - For more information, please check the resources below: | |||||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||||
| # [Quick Start](#contents) | |||||
| After installing MindSpore via the official website, you can start training and evaluation as follows: | |||||
| ```python | |||||
| # enter script dir, train CNNDirectionModel | |||||
| sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) | |||||
| # enter script dir, evaluate CNNDirectionModel | |||||
| sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] | |||||
| ``` | |||||
| # [Script Description](#contents) | |||||
| ## [Script and Sample Code](#contents) | |||||
| ```shell | |||||
| ├── cv | |||||
| ├── cnn_direction_model | |||||
| ├── README.md // descriptions about cnn_direction_model | |||||
| ├── requirements.txt // packages needed | |||||
| ├── scripts | |||||
| │ ├──run_distribute_train_ascend.sh // distributed training in ascend | |||||
| │ ├──run_standalone_eval_ascend.sh // evaluate in ascend | |||||
| │ ├──run_standalone_train_ascend.sh // train standalone in ascend | |||||
| ├── src | |||||
| │ ├──dataset.py // creating dataset | |||||
| │ ├──cnn_direction_model.py // cnn_direction_model architecture | |||||
| │ ├──config.py // parameter configuration | |||||
| │ ├──create_mindrecord.py // convert raw data to mindrecords | |||||
| ├── train.py // training script | |||||
| ├── eval.py // evaluation script | |||||
| ``` | |||||
| ## [Script Parameters](#contents) | |||||
| ```python | |||||
| Major parameters in config.py as follows: | |||||
| --data_root_train: The path to the raw training data images for conversion to mindrecord script. | |||||
| --data_root_test: The path to the raw test data images for conversion to mindrecord script. | |||||
| --test_annotation_file: The path to the raw training annotation file. | |||||
| --train_annotation_file: The path to the raw test annotation file. | |||||
| --mindrecord_dir: The path to which create_mindrecord.py uses to save the resulting mindrecords for training and testing. | |||||
| --epoch_size: Total training epochs. | |||||
| --batch_size: Training batch size. | |||||
| --im_size_h: Image height used as input to the model. | |||||
| --im_size_w: Image width used as input the model. | |||||
| ``` | |||||
| ## [Training Process](#contents) | |||||
| ### Training | |||||
| - running on Ascend | |||||
| ```python | |||||
| sh run_standalone_train_ascend.sh path-to-train-mindrecords pre-trained-chkpt(optional) | |||||
| ``` | |||||
| The model checkpoint will be saved script/train. | |||||
| ## [Evaluation Process](#contents) | |||||
| ### Evaluation | |||||
| Before running the command below, please check the checkpoint path used for evaluation. | |||||
| - running on Ascend | |||||
| ```python | |||||
| sh run_standalone_eval_ascend.sh path-to-test-mindrecords trained-chkpt-path | |||||
| ``` | |||||
| Results of evaluation will be printed after evaluation process is completed. | |||||
| # [Model Description](#contents) | |||||
| ## [Performance](#contents) | |||||
| ### Evaluation Performance | |||||
| | Parameters | Ascend | | |||||
| | -------------------------- | ------------------------------------------------------------| | |||||
| | Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G | | |||||
| | uploaded Date | 01/15/2020 (month/day/year) | | |||||
| | MindSpore Version | 1.1 | | |||||
| | Dataset | FSNS | | |||||
| | Training Parameters | epoch=1, steps=104,477, batch_size = 20, lr=1e-07 | | |||||
| | Optimizer | Adam | | |||||
| | Loss Function | Softmax Cross Entropy | | |||||
| | outputs | top 1 accuracy | | |||||
| | Overall accuracy | 91.72% | | |||||
| | Speed | 583 ms/step | | |||||
| | Total time | 17 hours | | |||||
| # [Description of Random Situation](#contents) | |||||
| In train.py, we set some seeds before training. | |||||
| # [ModelZoo Homepage](#contents) | |||||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -12,7 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """train resnet.""" | |||||
| """test direction model.""" | |||||
| import argparse | import argparse | ||||
| import os | import os | ||||
| import random | import random | ||||
| @@ -45,8 +45,10 @@ if __name__ == '__main__': | |||||
| context.set_context(device_id=device_id) | context.set_context(device_id=device_id) | ||||
| # create dataset | # create dataset | ||||
| dataset = create_dataset_eval(args_opt.dataset_path + "/ocr_eval_pos.mindrecord", config=config) | |||||
| step_size = dataset.get_dataset_size() | |||||
| dataset_name = config.dataset_name | |||||
| dataset_lr, dataset_rl = create_dataset_eval(args_opt.dataset_path + "/" + dataset_name + | |||||
| ".mindrecord0", config=config, dataset_name=dataset_name) | |||||
| step_size = dataset_lr.get_dataset_size() | |||||
| print("step_size ", step_size) | print("step_size ", step_size) | ||||
| @@ -65,5 +67,7 @@ if __name__ == '__main__': | |||||
| model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'}) | model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'}) | ||||
| # eval model | # eval model | ||||
| res = model.eval(dataset, dataset_sink_mode=False) | |||||
| print("result:", res, "ckpt=", args_opt.checkpoint_path) | |||||
| res_lr = model.eval(dataset_lr, dataset_sink_mode=False) | |||||
| res_rl = model.eval(dataset_rl, dataset_sink_mode=False) | |||||
| print("result on upright images:", res_lr, "ckpt=", args_opt.checkpoint_path) | |||||
| print("result on 180 degrees rotated images:", res_rl, "ckpt=", args_opt.checkpoint_path) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -17,11 +17,24 @@ network config setting, will be used in train.py and eval.py | |||||
| """ | """ | ||||
| from easydict import EasyDict as ed | from easydict import EasyDict as ed | ||||
| config1 = ed({ | config1 = ed({ | ||||
| # dataset metadata | |||||
| "dataset_name": "fsns", | |||||
| # annotation files paths | |||||
| "train_annotation_file": "path-to-file", | |||||
| "test_annotation_file": "path-to-file", | |||||
| # dataset root paths | |||||
| "data_root_train": "path-to-dir", | |||||
| "data_root_test": "path-to-dir", | |||||
| # mindrecord target locations | |||||
| "mindrecord_dir": "path-to-dir", | |||||
| # training and testing params | |||||
| "batch_size": 8, | "batch_size": 8, | ||||
| "epoch_size": 5, | "epoch_size": 5, | ||||
| "pretrain_epoch_size": 0, | "pretrain_epoch_size": 0, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_steps": 2500, | |||||
| "save_checkpoint_epochs": 10, | "save_checkpoint_epochs": 10, | ||||
| "keep_checkpoint_max": 20, | "keep_checkpoint_max": 20, | ||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| @@ -0,0 +1,108 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import os | |||||
| from mindspore.mindrecord import FileWriter | |||||
| from config import config1 as config | |||||
| FAIL = 1 | |||||
| SUCCESS = 0 | |||||
| def get_images(image_dir, annot_files): | |||||
| """ | |||||
| Get file paths that are in image_dir, annotation file is used to get the file names. | |||||
| Args: | |||||
| image_dir(string): images directory. | |||||
| annot_files(list(string)) : annotation files. | |||||
| Returns: | |||||
| status code(int), status of process(string), image ids(list(int)), image paths(dict(int,string)) | |||||
| """ | |||||
| print("Process [Get Images] started") | |||||
| if not os.path.isdir(image_dir): | |||||
| return FAIL, "{} is not a directory. Please check the src/config.py file.".format(image_dir), [], {} | |||||
| image_files_dict = {} | |||||
| images = [] | |||||
| img_id = 0 | |||||
| # create a dictionary of image file paths | |||||
| for annot_file in annot_files: | |||||
| if not os.path.exists(annot_file): | |||||
| return FAIL, "{} was not found.".format(annot_file), [], {} | |||||
| lines = open(annot_file, 'r').readlines() | |||||
| for line in lines: | |||||
| # extract file name | |||||
| file_name = line.split('\t')[0] | |||||
| image_path = os.path.join(image_dir, file_name) | |||||
| if not os.path.isfile(image_path): | |||||
| return FAIL, "{} is not a file.".format(image_path), [], {} | |||||
| # add path to dictionary | |||||
| images.append(img_id) | |||||
| image_files_dict[img_id] = image_path | |||||
| img_id += 1 | |||||
| return SUCCESS, "Successfully retrieved {} images.".format(str(len(images))), images, image_files_dict | |||||
| def write_mindrecord_images(image_ids, image_dict, mindrecord_dir, data_schema, file_num=8): | |||||
| writer = FileWriter(os.path.join(mindrecord_dir, config.dataset_name + ".mindrecord"), shard_num=file_num) | |||||
| writer.add_schema(data_schema, config.dataset_name) | |||||
| len_image_dict = len(image_dict) | |||||
| sample_count = 0 | |||||
| for img_id in image_ids: | |||||
| image_path = image_dict[img_id] | |||||
| with open(image_path, 'rb') as f: | |||||
| img = f.read() | |||||
| row = {"image": img} | |||||
| sample_count += 1 | |||||
| writer.write_raw_data([row]) | |||||
| print("Progress {} / {}".format(str(sample_count), str(len_image_dict)), end='\r') | |||||
| writer.commit() | |||||
| def create_mindrecord(): | |||||
| annot_files_train = [config.train_annotation_file] | |||||
| annot_files_test = [config.test_annotation_file] | |||||
| ret_code, ret_message, images_train, image_path_dict_train = get_images(image_dir=config.data_root_train, | |||||
| annot_files=annot_files_train) | |||||
| if ret_code != SUCCESS: | |||||
| return ret_code, message, "", "" | |||||
| ret_code, ret_message, images_test, image_path_dict_test = get_images(image_dir=config.data_root_test, | |||||
| annot_files=annot_files_test) | |||||
| if ret_code != SUCCESS: | |||||
| return ret_code, ret_message, "", "" | |||||
| data_schema = {"image": {"type": "bytes"}} | |||||
| train_target = os.path.join(config.mindrecord_dir, "train") | |||||
| test_target = os.path.join(config.mindrecord_dir, "test") | |||||
| if not os.path.exists(train_target): | |||||
| os.mkdir(train_target) | |||||
| if not os.path.exists(test_target): | |||||
| os.mkdir(test_target) | |||||
| print("Creating training mindrecords: ") | |||||
| write_mindrecord_images(images_train, image_path_dict_train, train_target, data_schema) | |||||
| print("Creating test mindrecords: ") | |||||
| write_mindrecord_images(images_test, image_path_dict_test, test_target, data_schema) | |||||
| return SUCCESS, "Successful mindrecord creation.", train_target, test_target | |||||
| if __name__ == "__main__": | |||||
| # start creating mindrecords from raw images and annots | |||||
| # provide root path to raw data in the config file | |||||
| code, message, train_target_dir, test_target_dir = create_mindrecord() | |||||
| if code != SUCCESS: | |||||
| print("Process done with status code: {}. Error: {}".format(code, message)) | |||||
| else: | |||||
| print("Process done with status: {}. Training and testing data are saved to {} and {} respectively." | |||||
| .format(message, train_target_dir, test_target_dir)) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -139,9 +139,17 @@ def rotate_and_set_neg(img, label): | |||||
| label = label - 1 | label = label - 1 | ||||
| img_rotate = np.rot90(img) | img_rotate = np.rot90(img) | ||||
| img_rotate = np.rot90(img_rotate) | img_rotate = np.rot90(img_rotate) | ||||
| # return img_rotate, label | |||||
| return img_rotate, np.array(label).astype(np.int32) | return img_rotate, np.array(label).astype(np.int32) | ||||
| def crop_image(h_crop, w_crop): | |||||
| def crop_fun(img): | |||||
| return img[h_crop[0]:h_crop[1], w_crop[0]:w_crop[1], :] | |||||
| return crop_fun | |||||
| def create_label(label=1): | |||||
| def label_fun(img): | |||||
| return img, np.array(label).astype(np.int32) | |||||
| return label_fun | |||||
| def rotate(img, label): | def rotate(img, label): | ||||
| img_rotate = np.rot90(img) | img_rotate = np.rot90(img) | ||||
| @@ -165,13 +173,14 @@ def transform_image(img, label): | |||||
| return data.transpose((0, 3, 1, 2))[0], label | return data.transpose((0, 3, 1, 2))[0], label | ||||
| def create_dataset_train(mindrecord_file_pos, config): | |||||
| def create_dataset_train(mindrecord_file_pos, config, dataset_name='ocr'): | |||||
| """ | """ | ||||
| create a train dataset | create a train dataset | ||||
| Args: | Args: | ||||
| mindrecord_file_pos(string): mindrecord file for positive samples. | mindrecord_file_pos(string): mindrecord file for positive samples. | ||||
| config(dict): config of dataset. | config(dict): config of dataset. | ||||
| dataset_name(string): name of dataset being used, e.g. 'fsns'. | |||||
| Returns: | Returns: | ||||
| dataset | dataset | ||||
| @@ -179,11 +188,15 @@ def create_dataset_train(mindrecord_file_pos, config): | |||||
| rank_size = int(os.getenv("RANK_SIZE", '1')) | rank_size = int(os.getenv("RANK_SIZE", '1')) | ||||
| rank_id = int(os.getenv("RANK_ID", '0')) | rank_id = int(os.getenv("RANK_ID", '0')) | ||||
| decode = C.Decode() | decode = C.Decode() | ||||
| data_set = ds.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=4, | |||||
| columns_list = ["image", "label"] if dataset_name != 'fsns' else ["image"] | |||||
| data_set = ds.MindDataset(mindrecord_file_pos, columns_list=columns_list, num_parallel_workers=4, | |||||
| num_shards=rank_size, shard_id=rank_id, shuffle=True) | num_shards=rank_size, shard_id=rank_id, shuffle=True) | ||||
| data_set = data_set.map(operations=decode, input_columns=["image"], num_parallel_workers=8) | data_set = data_set.map(operations=decode, input_columns=["image"], num_parallel_workers=8) | ||||
| if dataset_name == 'fsns': | |||||
| data_set = data_set.map(operations=crop_image((0, 150), (0, 150)), | |||||
| input_columns=["image"], num_parallel_workers=8) | |||||
| data_set = data_set.map(operations=create_label(), input_columns=["image"], output_columns=["image", "label"], | |||||
| column_order=["image", "label"], num_parallel_workers=8) | |||||
| augmentor = Augmentor(config.augment_severity, config.augment_prob) | augmentor = Augmentor(config.augment_severity, config.augment_prob) | ||||
| operation = augmentor.process | operation = augmentor.process | ||||
| data_set = data_set.map(operations=operation, input_columns=["image"], | data_set = data_set.map(operations=operation, input_columns=["image"], | ||||
| @@ -217,7 +230,7 @@ def resize_image(img, label): | |||||
| return data.transpose((0, 3, 1, 2))[0], label | return data.transpose((0, 3, 1, 2))[0], label | ||||
| def create_dataset_eval(mindrecord_file_pos, config): | |||||
| def create_dataset_eval(mindrecord_file_pos, config, dataset_name='ocr'): | |||||
| """ | """ | ||||
| create an eval dataset | create an eval dataset | ||||
| @@ -226,16 +239,21 @@ def create_dataset_eval(mindrecord_file_pos, config): | |||||
| config(dict): config of dataset. | config(dict): config of dataset. | ||||
| Returns: | Returns: | ||||
| dataset | |||||
| dataset with images upright | |||||
| dataset with images 180-degrees rotated | |||||
| """ | """ | ||||
| rank_size = int(os.getenv("RANK_SIZE", '1')) | rank_size = int(os.getenv("RANK_SIZE", '1')) | ||||
| rank_id = int(os.getenv("RANK_ID", '0')) | rank_id = int(os.getenv("RANK_ID", '0')) | ||||
| decode = C.Decode() | decode = C.Decode() | ||||
| data_set = ds.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=1, | |||||
| columns_list = ["image", "label"] if dataset_name != 'fsns' else ["image"] | |||||
| data_set = ds.MindDataset(mindrecord_file_pos, columns_list=columns_list, num_parallel_workers=1, | |||||
| num_shards=rank_size, shard_id=rank_id, shuffle=False) | num_shards=rank_size, shard_id=rank_id, shuffle=False) | ||||
| data_set = data_set.map(operations=decode, input_columns=["image"], num_parallel_workers=8) | data_set = data_set.map(operations=decode, input_columns=["image"], num_parallel_workers=8) | ||||
| if dataset_name == 'fsns': | |||||
| data_set = data_set.map(operations=crop_image((0, 150), (0, 150)), | |||||
| input_columns=["image"], num_parallel_workers=8) | |||||
| data_set = data_set.map(operations=create_label(), input_columns=["image"], output_columns=["image", "label"], | |||||
| column_order=["image", "label"], num_parallel_workers=8) | |||||
| global image_height | global image_height | ||||
| global image_width | global image_width | ||||
| image_height = config.im_size_h | image_height = config.im_size_h | ||||
| @@ -243,7 +261,12 @@ def create_dataset_eval(mindrecord_file_pos, config): | |||||
| data_set = data_set.map(operations=resize_image, input_columns=["image", "label"], | data_set = data_set.map(operations=resize_image, input_columns=["image", "label"], | ||||
| num_parallel_workers=config.work_nums, | num_parallel_workers=config.work_nums, | ||||
| python_multiprocessing=False) | python_multiprocessing=False) | ||||
| dataset_lr, dataset_rl = data_set.split([0.5, 0.5]) | |||||
| dataset_rl = dataset_rl.map(operations=rotate_and_set_neg, input_columns=["image", "label"], | |||||
| num_parallel_workers=config.work_nums, | |||||
| python_multiprocessing=False) | |||||
| # apply batch operations | # apply batch operations | ||||
| data_set = data_set.batch(1, drop_remainder=True) | |||||
| dataset_lr = dataset_lr.batch(1, drop_remainder=True) | |||||
| dataset_rl = dataset_rl.batch(1, drop_remainder=True) | |||||
| return data_set | |||||
| return dataset_lr, dataset_rl | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -38,7 +38,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| parser = argparse.ArgumentParser(description='Image classification') | parser = argparse.ArgumentParser(description='Image classification') | ||||
| parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') | parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') | ||||
| parser.add_argument('--device_num', type=int, default=1, help='Device num.') | parser.add_argument('--device_num', type=int, default=1, help='Device num.') | ||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | ||||
| parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') | parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') | ||||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | ||||
| @@ -72,7 +71,9 @@ if __name__ == '__main__': | |||||
| init() | init() | ||||
| # create dataset | # create dataset | ||||
| dataset = create_dataset_train(args_opt.dataset_path + "/ocr_pos.mindrecord0", config=config) | |||||
| dataset_name = config.dataset_name | |||||
| dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name + | |||||
| ".mindrecord0", config=config, dataset_name=dataset_name) | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # define net | # define net | ||||
| @@ -99,7 +100,7 @@ if __name__ == '__main__': | |||||
| loss_cb = LossMonitor() | loss_cb = LossMonitor() | ||||
| cb = [time_cb, loss_cb] | cb = [time_cb, loss_cb] | ||||
| if config.save_checkpoint: | if config.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=2500, | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, | |||||
| keep_checkpoint_max=config.keep_checkpoint_max) | keep_checkpoint_max=config.keep_checkpoint_max) | ||||
| ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck) | ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck) | ||||
| cb += [ckpt_cb] | cb += [ckpt_cb] | ||||