# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn import Dense
from mindspore import Tensor
import mindspore.context as context

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class MomentumNet(nn.Cell):
    """Tiny network: flatten the input, then one 16 -> 10 Dense layer.

    The Dense weight is initialised to a constant 0.01 so that training
    runs are deterministic.
    """

    def __init__(self):
        super(MomentumNet, self).__init__()
        self.batch_size = 1
        self.reshape = P.Reshape()
        weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
        self.fc1 = Dense(16, 10, weight_init=weight)

    def construct(self, input_x):
        # Flatten everything after the batch dimension into one feature axis.
        output = self.reshape(input_x, (self.batch_size, -1))
        output = self.fc1(output)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_momentum():
    """Train MomentumNet with the Momentum optimizer and check the loss drops.

    The network is trained for ``epoch`` steps on a single fixed sample
    (label 0), using lr=0.1 and momentum=0.9. The test checks that every
    recorded loss is finite and that the last loss is lower than the first.

    Returns:
        list of Tensor: the per-step losses returned by the training cell.
        This is kept for backward compatibility with any external caller.
    """
    epoch = 13
    net = MomentumNet()
    learning_rate = 0.1
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()

    # The same sample is fed on every step, so build the tensors only once.
    data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
    label = Tensor(np.array([0]).astype(np.int32))

    losses = []
    for _ in range(epoch):
        loss = train_network(data, label)
        losses.append(loss)

    print("================================")
    print(losses)

    # The original file had an unused tolerance and a stale "expected output"
    # string, so it asserted nothing. Check the actual training behaviour.
    loss_values = [float(np.mean(step_loss.asnumpy())) for step_loss in losses]
    assert len(loss_values) == epoch
    assert all(np.isfinite(value) for value in loss_values)
    # Repeated momentum-SGD steps on one fixed sample should lower its loss.
    assert loss_values[-1] < loss_values[0]
    return losses