|
|
|
@@ -33,9 +33,11 @@ from mindspore.common import set_seed |
|
|
|
|
|
|
|
set_seed(1) |
|
|
|
|
|
|
|
|
|
|
|
def modelarts_pre_process(): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
@moxing_wrapper(pre_process=modelarts_pre_process) |
|
|
|
def train_lenet(): |
|
|
|
|
|
|
|
@@ -53,6 +55,8 @@ def train_lenet(): |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck) |
|
|
|
|
|
|
|
if config.device_target != "Ascend": |
|
|
|
if config.device_target == "GPU": |
|
|
|
context.set_context(enable_graph_kernel=True) |
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) |
|
|
|
else: |
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2") |
|
|
|
@@ -60,5 +64,6 @@ def train_lenet(): |
|
|
|
print("============== Starting Training ==============") |
|
|
|
model.train(config.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()]) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
train_lenet() |