@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
"""Transformer evaluation script.""" |
|
|
|
|
|
|
|
 import os
 import argparse
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
@@ -41,8 +41,13 @@ def run_gru_eval():
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
         device_id=args.device_id, save_graphs=False)
prefix = "multi30k_test_mindrecord_32" |
|
|
|
mindrecord_file = os.path.join(args.dataset_path, prefix) |
|
|
|
if not os.path.exists(mindrecord_file): |
|
|
|
print("dataset file {} not exists, please check!".format(mindrecord_file)) |
|
|
|
raise ValueError(mindrecord_file) |
|
|
|
     dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
-        dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
+        dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
     dataset_size = dataset.get_dataset_size()
     print("dataset size is {}".format(dataset_size))
     network = Seq2Seq(config, is_training=False)