Browse Source

!17300 Enable Graph Kernel for Lenet and LenetQuant on GPU

From: @zengzitao
Reviewed-by: @ckey_dou,@gaoxiong1
Signed-off-by: @ckey_dou
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
53e8cacede
2 changed files with 7 additions and 0 deletions
  1. +5
    -0
      model_zoo/official/cv/lenet/train.py
  2. +2
    -0
      model_zoo/official/cv/lenet_quant/train_quant.py

+ 5
- 0
model_zoo/official/cv/lenet/train.py View File

@@ -33,9 +33,11 @@ from mindspore.common import set_seed

set_seed(1)


def modelarts_pre_process():
pass


@moxing_wrapper(pre_process=modelarts_pre_process)
def train_lenet():

@@ -53,6 +55,8 @@ def train_lenet():
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)

if config.device_target != "Ascend":
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
@@ -60,5 +64,6 @@ def train_lenet():
print("============== Starting Training ==============")
model.train(config.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])


if __name__ == "__main__":
train_lenet()

+ 2
- 0
model_zoo/official/cv/lenet_quant/train_quant.py View File

@@ -51,6 +51,8 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
step_size = ds_train.get_dataset_size()

if args.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
# define fusion network
network = LeNet5Fusion(cfg.num_classes)



Loading…
Cancel
Save