You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ftrl.py 3.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test FTRL """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter, context
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, WithLossCell
  22. from mindspore.nn.optim import FTRL
  23. from mindspore.ops import operations as P
  24. @pytest.fixture(scope="module", autouse=True)
  25. def setup_teardown():
  26. context.set_context(enable_sparse=True)
  27. yield
  28. context.set_context(enable_sparse=False)
  29. class Net(nn.Cell):
  30. def __init__(self):
  31. super(Net, self).__init__()
  32. self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name='weight')
  33. self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name='bias')
  34. self.matmul = P.MatMul()
  35. self.biasAdd = P.BiasAdd()
  36. def construct(self, x):
  37. x = self.biasAdd(self.matmul(x, self.weight), self.bias)
  38. return x
  39. class NetWithSparseGatherV2(nn.Cell):
  40. """ NetWithSparseGatherV2 definition """
  41. def __init__(self):
  42. super(NetWithSparseGatherV2, self).__init__()
  43. self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
  44. self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
  45. self.axis = 0
  46. self.gather = P.SparseGatherV2()
  47. def construct(self, indices, label):
  48. return self.gather(self.weight1, indices, self.axis) + self.weight2
  49. def test_ftrl():
  50. """ test_ftrl """
  51. inputs = Tensor(np.ones([1, 64]).astype(np.float32))
  52. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  53. net = Net()
  54. net.set_train()
  55. loss = nn.SoftmaxCrossEntropyWithLogits()
  56. optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
  57. net_with_loss = WithLossCell(net, loss)
  58. train_network = TrainOneStepCell(net_with_loss, optimizer)
  59. _executor.compile(train_network, inputs, label)
  60. def test_spares_ftrl_compile():
  61. """ test sparse ftrl compile """
  62. indices = Tensor(np.array([0, 1]).astype(np.int32))
  63. label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
  64. net = NetWithSparseGatherV2()
  65. net.set_train()
  66. optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
  67. optimizer.target = 'CPU'
  68. train_network = TrainOneStepCell(net, optimizer)
  69. _executor.compile(train_network, indices, label)
  70. def test_spares_ftrl():
  71. """ test sparse ftrl"""
  72. indices = Tensor(np.array([0, 1]).astype(np.int32))
  73. label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
  74. net = NetWithSparseGatherV2()
  75. net.set_train()
  76. optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
  77. optimizer.target = 'Ascend'
  78. train_network = TrainOneStepCell(net, optimizer)
  79. _executor.compile(train_network, indices, label)