From a152c510c5d263d5216ae40f3f627b911089b19a Mon Sep 17 00:00:00 2001 From: Yanjun Peng Date: Mon, 28 Dec 2020 19:44:42 +0800 Subject: [PATCH] fix textcnn export on master branch --- model_zoo/official/cv/dpn/train.py | 2 +- model_zoo/official/nlp/textcnn/export.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/model_zoo/official/cv/dpn/train.py b/model_zoo/official/cv/dpn/train.py index ccedc04e38..283c38cf8f 100644 --- a/model_zoo/official/cv/dpn/train.py +++ b/model_zoo/official/cv/dpn/train.py @@ -51,7 +51,7 @@ def parse_args(): # distributed related parser.add_argument('--is_distributed', type=int, default=1, help='if multi device') parser.add_argument('--ckpt_path', type=str, default='', help='ckpt path to save') - parser.add_argument('--eval_each_epoch', type=int, default=0, help='ckpt path to save') + parser.add_argument('--eval_each_epoch', type=int, default=0, help='evaluate on each epoch') args, _ = parser.parse_known_args() args.image_size = config.image_size args.num_classes = config.num_classes diff --git a/model_zoo/official/nlp/textcnn/export.py b/model_zoo/official/nlp/textcnn/export.py index 311475bfc0..fa08c9ca71 100644 --- a/model_zoo/official/nlp/textcnn/export.py +++ b/model_zoo/official/nlp/textcnn/export.py @@ -21,9 +21,9 @@ import numpy as np from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context -from src.config import cfg +from src.config import cfg_mr, cfg_subj, cfg_sst2 from src.textcnn import TextCNN -from src.dataset import MovieReview +from src.dataset import MovieReview, SST2, Subjectivity parser = argparse.ArgumentParser(description='TextCNN export') parser.add_argument("--device_id", type=int, default=0, help="device id") @@ -32,7 +32,7 @@ parser.add_argument("--file_name", type=str, default="textcnn", help="output fil parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", help="device target") -parser.add_argument('--dataset_name', type=str, default='MR', choices=['MR'], +parser.add_argument('--dataset', type=str, default='MR', choices=['MR', 'SUBJ', 'SST2'], help='dataset name.') args = parser.parse_args() @@ -41,8 +41,15 @@ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, d if __name__ == '__main__': - if args.dataset_name == 'MR': + if args.dataset == 'MR': + cfg = cfg_mr instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) + elif args.dataset == 'SUBJ': + cfg = cfg_subj + instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) + elif args.dataset == 'SST2': + cfg = cfg_sst2 + instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) else: raise ValueError("dataset is not support.")