| @@ -32,7 +32,36 @@ from mindspore.nn import Cell | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as CP | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.nn.wrap.cell_wrapper import WithLossCell | |||
| from mindspore.train.callback import LossMonitor, Callback | |||
| from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||
| from mindspore.train.model import Model | |||
class MyTimeMonitor(Callback):
    """Training callback that measures epoch and step wall-clock time.

    Each step that completes in under 265 ms is counted as a "good" step;
    the accumulated count is exposed via ``good_step()`` so the test can
    assert on per-step training performance.
    """

    def __init__(self, data_size):
        super(MyTimeMonitor, self).__init__()
        # Number of steps per epoch, used to derive the per-step time.
        self.data_size = data_size
        # Count of steps that finished under the 265 ms threshold.
        self.total = 0
        # Initialize the timers here so epoch_end/step_end never raise
        # AttributeError if the framework invokes them without a matching
        # *_begin call (the original set these only in the begin hooks).
        self.epoch_time = time.time()
        self.step_time = time.time()

    def epoch_begin(self, run_context):
        """Record the epoch start timestamp."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Print the elapsed epoch time and average per-step time (ms)."""
        epoch_msseconds = (time.time() - self.epoch_time) * 1000
        per_step_mssconds = epoch_msseconds / self.data_size
        print("epoch time:{0}, per step time:{1}".format(epoch_msseconds, per_step_mssconds), flush=True)

    def step_begin(self, run_context):
        """Record the step start timestamp."""
        self.step_time = time.time()

    def step_end(self, run_context):
        """Print the step duration and count it if under the threshold."""
        step_msseconds = (time.time() - self.step_time) * 1000
        if step_msseconds < 265:
            self.total = self.total + 1
        print(f"step time:{step_msseconds}", flush=True)

    def good_step(self):
        """Return the number of steps that beat the 265 ms threshold."""
        return self.total
# Fix the RNG seeds so dataset shuffling and weight init are reproducible
# across test runs.
random.seed(1)
np.random.seed(1)
| @@ -303,12 +332,12 @@ def resnet50(batch_size, num_classes): | |||
| return ResNet(ResidualBlock, num_classes, batch_size) | |||
| def create_dataset(repeat_num=1, training=True, batch_size=32): | |||
| def create_dataset(repeat_num=1, training=True, batch_size=32, num_samples=1600): | |||
| data_home = "/home/workspace/mindspore_dataset" | |||
| data_dir = data_home + "/cifar-10-batches-bin" | |||
| if not training: | |||
| data_dir = data_home + "/cifar-10-verify-bin" | |||
| data_set = ds.Cifar10Dataset(data_dir) | |||
| data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples) | |||
| resize_height = 224 | |||
| resize_width = 224 | |||
| @@ -385,33 +414,25 @@ def test_pynative_resnet50(): | |||
| batch_size = 32 | |||
| num_classes = 10 | |||
| loss_scale = 128 | |||
| total_step = 50 | |||
| net = resnet50(batch_size, num_classes) | |||
| criterion = CrossEntropyLoss() | |||
| optimizer = Momentum(learning_rate=0.01, momentum=0.9, | |||
| params=filter(lambda x: x.requires_grad, net.get_parameters())) | |||
| data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size, num_samples=total_step * batch_size) | |||
| # define callbacks | |||
| time_cb = MyTimeMonitor(data_size=data_set.get_dataset_size()) | |||
| loss_cb = LossMonitor() | |||
| cb = [time_cb, loss_cb] | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False) | |||
| model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False) | |||
| # train model | |||
| model.train(1, data_set, callbacks=cb, | |||
| sink_size=data_set.get_dataset_size(), dataset_sink_mode=True) | |||
| net_with_criterion = WithLossCell(net, criterion) | |||
| net_with_criterion.set_grad() | |||
| train_network = GradWrap(net_with_criterion) | |||
| train_network.set_train() | |||
| step = 0 | |||
| max_step = 21 | |||
| exceed_num = 0 | |||
| data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size) | |||
| for element in data_set.create_dict_iterator(num_epochs=1): | |||
| step = step + 1 | |||
| if step > max_step: | |||
| break | |||
| start_time = time.time() | |||
| input_data = element["image"] | |||
| input_label = element["label"] | |||
| loss_output = net_with_criterion(input_data, input_label) | |||
| grads = train_network(input_data, input_label) | |||
| optimizer(grads) | |||
| end_time = time.time() | |||
| cost_time = end_time - start_time | |||
| print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) | |||
| if step > 1 and cost_time > 0.25: | |||
| exceed_num = exceed_num + 1 | |||
| assert exceed_num < 20 | |||
| assert time_cb.good_step() > 10 | |||