@@ -22,7 +22,7 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore import Tensor
-from mindspore.common.initializer import One
+from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum
@@ -35,10 +35,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) |
|
|
|
init() |
|
|
|
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL) |
|
|
|
np.random.seed(10) |
|
|
|
|
|
|
|
|
|
|
|
def weight_variable(): |
|
|
|
return One() |
|
|
|
return TruncatedNormal(0.01) |
|
|
|
|
|
|
|
|
|
|
|
def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): |
|
|
|
@@ -93,11 +94,9 @@ class BasicBlock(nn.Cell):
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample:
            identity = self.down_sample_layer(identity)
@@ -120,13 +119,10 @@ class ResidualBlock(nn.Cell):

        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
        self.bn1 = _fused_bn(out_chls, momentum=momentum)

        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
        self.bn2 = _fused_bn(out_chls, momentum=momentum)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
        self.bn3 = _fused_bn(out_channels, momentum=momentum)

        self.relu = P.ReLU()
        self.downsample = (in_channels != out_channels)
@@ -134,7 +130,6 @@ class ResidualBlock(nn.Cell):
        if self.downsample:
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
        elif self.stride != 1:
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')
@@ -144,19 +139,15 @@ class ResidualBlock(nn.Cell):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)
        elif self.stride != 1:
            identity = self.maxpool_down(identity)

@@ -211,7 +202,7 @@ class ResNet(nn.Cell):
        self.mean = P.ReduceMean(keep_dims=True)
        self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                  weight_init=weight_variable(),
-                                  bias_init=weight_variable())
+                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

@@ -231,7 +222,6 @@ class ResNet(nn.Cell):

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

@@ -277,6 +267,7 @@ class SoftmaxCrossEntropyExpand(_Loss):
        self.eps = Tensor(1e-24, mstype.float32)

    def construct(self, logit, label):
        logit = self.cast(logit, mstype.float32)
        logit_max = self.max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
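# For reference, SoftmaxCrossEntropyExpand builds the standard numerically stable
# softmax cross-entropy out of primitive ops. A minimal NumPy sketch of the same
# computation (illustrative only, not part of this patch; the function name and the
# mean reduction are assumptions, the real cell stays on MindSpore tensors):
import numpy as np


def _softmax_cross_entropy_reference(logits, labels, eps=1e-24):
    logits = logits.astype(np.float32)
    # subtract the per-row max before exp to avoid overflow, as the cell above does
    logit_max = logits.max(axis=-1, keepdims=True)
    exp = np.exp(logits - logit_max)
    exp_sum = exp.sum(axis=-1, keepdims=True)
    softmax = exp / exp_sum
    # sparse labels: take the probability assigned to each sample's true class
    true_class_prob = softmax[np.arange(labels.shape[0]), labels]
    return -np.log(true_class_prob + eps).mean()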
@@ -369,41 +360,19 @@ class ModelCallback(Callback):
        self.loss_list.append(result.asnumpy().mean())


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
-def test_train_feed(num_classes=8192):
+def test_train_feed(num_classes=65536):
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    data_gen = DataGenerator()
-    _, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
-    _, label_part = data_gen.label_data((32 * 2,))
+    _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
+    _, label_part = data_gen.label_data((32 * 8,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489]
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_train_feed2(num_classes=1001):
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
    _, label_part = data_gen.label_data((32 * 2,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097]
+    expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
    print(loss_value)
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
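

# Only one line of ModelCallback survives in the hunk above. For orientation, a
# hypothetical sketch of what such a loss-collecting callback typically looks like
# in MindSpore (the class name and hooks below are assumptions, not this file's code):
from mindspore.train.callback import Callback


class LossCollectCallback(Callback):
    def __init__(self):
        super(LossCollectCallback, self).__init__()
        self.loss_list = []

    def step_end(self, run_context):
        # net_outputs carries the loss returned by the training network for this step
        result = run_context.original_args().net_outputs
        self.loss_list.append(result.asnumpy().mean())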