You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_gather_v2_primitive.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. from mindspore.ops import composite as C
  17. from mindspore.common.parameter import ParameterTuple
  18. from mindspore.nn.optim import Momentum
  19. from mindspore.communication.management import init
  20. from mindspore.train import Model, ParallelMode
  21. import mindspore as ms
  22. import mindspore.nn as nn
  23. from mindspore.ops import operations as P
  24. from mindspore.ops import functional as F
  25. from mindspore.nn.loss.loss import _Loss
  26. from mindspore import Tensor
  27. from mindspore.common import dtype as mstype
  28. from mindspore.nn import Dense, Cell
  29. from mindspore import context
# Compile the whole network into a single graph (GRAPH_MODE) instead of
# executing op-by-op; required for the auto-parallel features used below.
context.set_context(mode=context.GRAPH_MODE)
  31. class Dataset():
  32. def __init__(self, predict, length=3):
  33. self.predict = predict
  34. self.index = 0
  35. self.length = length
  36. def __iter__(self):
  37. return self
  38. def __next__(self):
  39. if self.index >= self.length:
  40. raise StopIteration
  41. self.index += 1
  42. return (self.predict,)
  43. def reset(self):
  44. self.index = 0
  45. def get_dataset_size(self):
  46. return 128
  47. def get_repeat_count(self):
  48. return 1
  49. class GatherV2(_Loss):
  50. def __init__(self, batchsize):
  51. super(GatherV2, self).__init__()
  52. self.pow = P.Pow()
  53. emb_list = list(range(batchsize))
  54. emb1_list = emb_list[0::2]
  55. emb2_list = emb_list[1::2]
  56. self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
  57. self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
  58. self.gatherv2 = P.GatherV2()
  59. def construct(self, nembeddings):
  60. emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
  61. emb2 = self.gatherv2(nembeddings, self.emb2_param, 0)
  62. return self.pow((emb1 - emb2), 2.0)
  63. def get_loss(batchsize):
  64. return GatherV2(batchsize)
  65. def fc_with_initialize(input_channels, out_channels):
  66. return Dense(input_channels, out_channels)
  67. class BuildTrainNetwork(nn.Cell):
  68. def __init__(self, network, criterion):
  69. super(BuildTrainNetwork, self).__init__()
  70. self.network = network
  71. self.criterion = criterion
  72. def construct(self, input_data):
  73. embeddings = self.network(input_data)
  74. loss = self.criterion(embeddings)
  75. return loss
  76. class TrainOneStepCell(Cell):
  77. def __init__(self, network, optimizer, sens=1.0):
  78. super(TrainOneStepCell, self).__init__(auto_prefix=False)
  79. self.network = network
  80. self.network.add_flags(defer_inline=True)
  81. self.weights = ParameterTuple(network.trainable_params())
  82. self.optimizer = optimizer
  83. self.grad = C.GradOperation('grad',
  84. get_by_list=True,
  85. sens_param=True)
  86. self.sens = sens
  87. def construct(self, data):
  88. weights = self.weights
  89. loss = self.network(data)
  90. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  91. grads = self.grad(self.network, weights)(data, sens)
  92. return F.depend(loss, self.optimizer(grads))
  93. def test_trains():
  94. init()
  95. lr = 0.1
  96. momentum = 0.9
  97. max_epoch = 20
  98. device_number = 32
  99. batch_size_per_device = 128
  100. input_channels = 256
  101. out_channels = 512
  102. context.reset_auto_parallel_context()
  103. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number)
  104. predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
  105. dataset = Dataset(predict, 4)
  106. network = fc_with_initialize(input_channels, out_channels)
  107. network.set_train()
  108. criterion = get_loss(batch_size_per_device * device_number)
  109. train_network = BuildTrainNetwork(network, criterion)
  110. train_network.set_train()
  111. opt = Momentum(train_network.trainable_params(), lr, momentum)
  112. train_net = TrainOneStepCell(train_network, opt).set_train()
  113. model = Model(train_net)
  114. model.train(max_epoch, dataset, dataset_sink_mode=False)
  115. context.reset_auto_parallel_context()
# Allow running this test directly as a script.
if __name__ == "__main__":
    test_trains()