
!6357 add grad acc test case

Merge pull request !6357 from jinyaohui/master
commit daff211538 by mindspore-ci-bot, 5 years ago (tags/v1.0.0)

2 changed files with 220 additions and 2 deletions
  1. mindspore/ccsrc/pipeline/jit/parse/parse.cc (+0, -2)
  2. tests/st/networks/test_gradient_accumulation.py (+220, -0)

mindspore/ccsrc/pipeline/jit/parse/parse.cc (+0, -2)

@@ -1816,8 +1816,6 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
     parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
   }
 
-  UpdataParam(func_graph, cell);
-
   // ret = cell_obj(*arg, *kwargs)
   auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});



tests/st/networks/test_gradient_accumulation.py (+220, -0)

@@ -0,0 +1,220 @@
import os

import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import save_checkpoint

_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")
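# _sum_op adds a freshly computed gradient into the persistent grad_sum
# buffer; _clear_op resets that buffer to zeros between optimizer steps.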


@_sum_op.register("Tensor", "Tensor")
def _cumulative_grad(grad_sum, grad):
    """Apply grad sum to cumulative gradient."""
    add = P.AssignAdd()
    return add(grad_sum, grad)


@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Apply zero to clear grad_sum."""
    success = True
    success = F.depend(success, F.assign(grad_sum, zero))
    return success


class LeNet5(nn.Cell):
    """
    LeNet network

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


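# Runs one forward/backward pass and accumulates the new gradients into
# grad_sum via _sum_op; the weights themselves are not updated here.
# `sens` is the initial output sensitivity fed to GradOperation.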
class TrainForwardBackward(Cell):
    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super(TrainForwardBackward, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))


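# Applies the accumulated gradients to the weights with the wrapped optimizer.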
class TrainOptim(Cell):
    def __init__(self, optimizer, grad_sum):
        super(TrainOptim, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        return self.optimizer(self.grad_sum)


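# Zeroes grad_sum so the next accumulation window starts fresh.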
class TrainClear(Cell):
    def __init__(self, grad_sum, zeros):
        super(TrainClear, self).__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
        return success


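# Ties the three cells together: accumulate gradients for mini_steps batches,
# then apply them in a single optimizer step and clear the buffer.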
class GradientAccumulation:
    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build forward and backward network."""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build optimizer network."""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build clear network."""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network

    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. Data is passed to the network directly (no dataset sinking).
        """
        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)

        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
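                # Every mini_steps batches, take one optimizer step on the
                # summed gradients, then reset the accumulation buffer.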
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()

            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create the MNIST dataset for training or testing.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gradient_accumulation():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)

    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = GradientAccumulation(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train_process(2, ds_train, mini_steps=4)
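
To run this test standalone (assuming an Ascend device is available and the
MNIST data exists at the hard-coded /home/workspace/mindspore_dataset/mnist
path), a plain pytest invocation should suffice:

    pytest -sv tests/st/networks/test_gradient_accumulation.py::test_gradient_accumulation

The pattern under test, reduced to plain Python (every name below is
illustrative, not part of the test or of MindSpore's API): per-batch gradients
are summed into a buffer, a single optimizer step is taken on the sum, and the
buffer is cleared. Note the sum is not divided by mini_steps, so the effective
step size grows with the number of accumulated batches.

    # Minimal sketch of the accumulate/apply/clear cycle for one scalar weight.
    def accumulated_sgd_step(w, batch_grads, lr):
        grad_sum = 0.0
        for g in batch_grads:  # TrainForwardBackward + _sum_op: grad_sum += g
            grad_sum += g
        w = w - lr * grad_sum  # TrainOptim: one update on the summed gradient
        grad_sum = 0.0         # TrainClear (_clear_op): reset the buffer
        return w

    print(accumulated_sgd_step(1.0, [0.1, 0.2, 0.3, 0.4], lr=0.01))  # -> 0.99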
