@@ -0,0 +1,220 @@
import os

import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import save_checkpoint

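# The two MultitypeFuncGraph containers below hold the per-tensor rules used for
# gradient accumulation: the overload registered on _sum_op adds the newest
# gradient into a running sum, and the one on _clear_op resets that sum to zero.
# C.HyperMap later maps these rules over every entry of the parameter tuples,
# roughly: hyper_map(F.partial(_sum_op), grad_sum_tuple, grad_tuple).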
_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")


@_sum_op.register("Tensor", "Tensor")
def _cumulative_grad(grad_sum, grad):
    """Add the current gradient into the accumulated gradient."""
    add = P.AssignAdd()
    return add(grad_sum, grad)


@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Apply zero to clear grad_sum."""
    success = True
    success = F.depend(success, F.assign(grad_sum, zero))
    return success


class LeNet5(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of input channels. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
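

# Manual gradient accumulation is split across three small cells:
#   * TrainForwardBackward runs the forward and backward pass for one batch and
#     adds the resulting gradients into the persistent grad_sum parameters
#     (via HyperMap and the _sum_op rule) instead of applying them.
#   * TrainOptim feeds the accumulated gradients to the optimizer once enough
#     batches have been processed.
#   * TrainClear zeroes grad_sum so the next accumulation window starts fresh.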
class TrainForwardBackward(Cell):
    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super(TrainForwardBackward, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))


class TrainOptim(Cell):
    def __init__(self, optimizer, grad_sum):
        super(TrainOptim, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        return self.optimizer(self.grad_sum)


class TrainClear(Cell):
    def __init__(self, grad_sum, zeros):
        super(TrainClear, self).__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
        return success
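

# GradientAccumulation wires the three cells together. It clones the optimizer's
# parameters twice: grad_sum holds the running gradient totals and zeros holds
# the all-zero tensors used to reset them. A minimal usage sketch, mirroring the
# test at the bottom of this file:
#
#     model = GradientAccumulation(LeNet5(10), net_loss, net_opt)
#     model.train_process(2, ds_train, mini_steps=4)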
class GradientAccumulation:
    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build forward and backward network"""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build optimizer network"""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build clear network"""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network
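
    # train_process drives the accumulate/step/clear cycle: every batch runs the
    # forward/backward cell to add gradients into grad_sum, and every
    # `mini_steps` batches the optimizer is applied once and grad_sum is cleared.
    # Note that gradients from a trailing partial group (when the number of
    # batches is not a multiple of mini_steps) are not applied at epoch end; they
    # stay in grad_sum and roll into the next update.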
    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. Data is passed to the network directly, without dataset sinking.
        """
        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)

        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()

            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")
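

# create_dataset builds the standard LeNet-style MNIST input pipeline: resize the
# 28x28 images to 32x32, rescale pixel values to [0, 1], normalize with the usual
# MNIST mean/std (0.1307 / 0.3081, encoded below as shift_nml and rescale_nml),
# transpose HWC to CHW, then shuffle, batch, and repeat.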
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create the MNIST dataset for training or evaluation.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
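

# The test below trains LeNet5 on MNIST for two epochs in Ascend graph mode with
# mini_steps=4: the Momentum optimizer is applied once every 4 batches of 32
# samples, i.e. each update uses gradients accumulated over an effective batch of
# 128 samples. Note that gradients are summed rather than averaged across the
# mini steps, so the effective step size scales with mini_steps for a fixed
# learning rate.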
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gradient_accumulation():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)

    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = GradientAccumulation(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train_process(2, ds_train, mini_steps=4)