@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
-from mindspore.train.callback import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -59,13 +59,33 @@ class MyTimeMonitor(Callback):
     def step_begin(self, run_context):
         self.step_time = time.time()
     def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
         step_mseconds = (time.time() - self.step_time) * 1000
         fps = self.batch_size / step_mseconds *1000 * self.size
-        print("Epoch time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ")
-def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16"):
-    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)
+        print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss),
+              "Epoch time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True)
+def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16",
+                   device_num=1):
+    if device_num == 1:
+        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)
+    else:
+        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True,
+                                   num_shards=device_num, shard_id=get_rank())
     image_size = 224
     mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
     std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
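
For context, the sketch below is not part of the patch; it shows one way the new device_num parameter could be driven in a data-parallel run. The dataset path and batch size are placeholder assumptions, while init, get_rank and get_group_size come from the imports changed in the first hunk.

# Sketch only: wiring the new device_num argument for a multi-GPU run.
# The dataset path and batch size are placeholders, not values from the patch.
from mindspore.communication.management import init, get_group_size

init()                          # set up the communication group (NCCL on the GPU backend)
device_num = get_group_size()   # number of participating devices

# With num_shards/shard_id, each rank reads a disjoint 1/device_num slice
# of the ImageFolder data instead of the full dataset.
dataset = create_dataset("/path/to/imagenet/train", do_train=True,
                         batch_size=32, dtype="fp16", device_num=device_num)
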
@@ -185,8 +205,7 @@ def train():
     if mode == context.PYNATIVE_MODE:
         print_per_steps = 1
     time_cb = MyTimeMonitor(total_batch, print_per_steps)
-    loss_cb = LossMonitor()
-    cb = [time_cb, loss_cb]
+    cb = [time_cb]
     if save_ckpt:
         config_ck = CheckpointConfig(save_checkpoint_steps=5 * step_size, keep_checkpoint_max=5)
         ckpt_cb = ModelCheckpoint(prefix="resnet_benchmark", directory=ckpt_save_dir, config=config_ck)
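
As a usage note, the following is an assumption about the rest of train(), which this hunk does not show: the trimmed callback list would typically be handed to Model.train() unchanged, since MyTimeMonitor alone now reports both loss and throughput.

# Hypothetical continuation of train(), not shown in this diff:
if save_ckpt:
    cb += [ckpt_cb]                     # checkpointing stays optional
# 'model', 'epoch_size' and 'dataset' are assumed to be defined earlier in
# train(); MyTimeMonitor is the only monitoring callback that remains.
model.train(epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)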