You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tdt_data_transfer.py 5.9 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import time
  16. import numpy as np
  17. import pytest
  18. from mindspore import context, nn, Tensor
  19. from mindspore import log as logger
  20. from mindspore.common.api import _cell_graph_executor
  21. from mindspore.common import dtype as mstype
  22. from mindspore.ops import operations as P
  23. import mindspore.dataset as de
  24. from mindspore.dataset.vision import c_transforms as c_vision
  25. from mindspore.dataset.transforms import c_transforms as c_trans
  26. DATA_DIR = "/home/workspace/mindspore_dataset/cifar-10-verify-bin"
  27. def dataset_cifar(dataset_path=None, batch_size=32, repeat_num=1, num_rows=9600, distribution_num=None, shard_id=None,
  28. drop_remainder=True, usage=None, shuffle=False, num_workers=8, resize_size=32, pad_info=None):
  29. if dataset_path is None:
  30. dataset_path = DATA_DIR
  31. ds = de.Cifar10Dataset(dataset_path, num_samples=num_rows, num_shards=distribution_num, shard_id=shard_id,
  32. shuffle=shuffle, usage=usage, num_parallel_workers=num_workers)
  33. typecast_op = c_trans.TypeCast(mstype.int32)
  34. ds = ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=num_workers)
  35. image_op_list = [c_vision.Resize(resize_size),
  36. c_vision.Rescale(1.0 / 255.0, 0.0),
  37. c_vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  38. c_vision.HWC2CHW()]
  39. ds = ds.map(input_columns="image", operations=image_op_list, num_parallel_workers=num_workers)
  40. ds = ds.batch(batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers, pad_info=pad_info)
  41. ds = ds.repeat(repeat_num)
  42. return ds
  43. def op_network_with_epoch(network, step_num):
  44. iter_num = 0
  45. network.set_train()
  46. for _ in range(step_num):
  47. op_return = network()
  48. op_return = op_return.asnumpy()
  49. logger.info("Op_return is : %s", op_return)
  50. iter_num += 1
  51. logger.info("Iter Num : %s", iter_num)
  52. return iter_num
  53. def convert_type(shapes, types):
  54. ms_types = []
  55. for np_shape, np_type in zip(shapes, types):
  56. input_np = np.zeros(np_shape, np_type)
  57. tensor = Tensor(input_np)
  58. ms_types.append(tensor.dtype)
  59. return ms_types
  60. def get_dataset_base_value(dataset):
  61. dataset_size = dataset.get_dataset_size()
  62. batch_size = dataset.get_batch_size()
  63. return dataset_size, batch_size
  64. def dataset_send_tdt(dataset):
  65. time.sleep(1)
  66. dataset.send(1)
  67. def get_dataset_shapes_and_types(dataset):
  68. dataset_shapes = dataset.output_shapes()
  69. np_types = dataset.output_types()
  70. dataset_types = convert_type(dataset_shapes, np_types)
  71. return dataset_shapes, dataset_types
  72. class SingleOpNetwork(nn.Cell):
  73. def __init__(self, shapes):
  74. super(SingleOpNetwork, self).__init__()
  75. self.shapes = tuple(shapes[0])
  76. self.Op_Reshape_network = P.Reshape()
  77. def construct(self, network_input):
  78. return self.Op_Reshape_network(network_input, self.shapes)
  79. class NetWithTDT(nn.Cell):
  80. def __init__(self, network, dataset_types, dataset_shapes, shared_name=''):
  81. super(NetWithTDT, self).__init__()
  82. self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_shapes), shared_name)
  83. self.Op_network = network
  84. def construct(self):
  85. next_input, _ = self.get_next()
  86. return self.Op_network(next_input)
  87. def op_network_with_step_num(dataset, step_num):
  88. dataset_shapes, dataset_types = get_dataset_shapes_and_types(dataset)
  89. _, batch_size = get_dataset_base_value(dataset)
  90. dataset = dataset.device_que()
  91. queue_name = dataset.queue_name
  92. net = SingleOpNetwork(dataset_shapes)
  93. net_with_dataset = NetWithTDT(net, dataset_types, dataset_shapes, queue_name)
  94. # when device type is Davinci, net should has get_next operation before call init_dataset
  95. _cell_graph_executor.init_dataset(dataset.queue_name, 1, batch_size, dataset_types, dataset_shapes, (), "")
  96. dataset_send_tdt(dataset)
  97. return op_network_with_epoch(net_with_dataset, step_num)
  98. @pytest.mark.level0
  99. @pytest.mark.platform_arm_ascend_training
  100. @pytest.mark.platform_x86_ascend_training
  101. @pytest.mark.env_onecard
  102. def test_tdt_consume_beyond_produce():
  103. context.set_context(mode=context.GRAPH_MODE)
  104. batch_size = 64
  105. repeat_num = 1
  106. num_rows = 640
  107. beyond_step_num = 1000
  108. ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)
  109. try:
  110. iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
  111. logger.info("out_iter_num:%s", iter_num)
  112. assert False
  113. except RuntimeError as e:
  114. logger.info("when dataset batch num is less than train loop, error msg is %s", e)
  115. assert True
  116. @pytest.mark.level0
  117. @pytest.mark.platform_arm_ascend_training
  118. @pytest.mark.platform_x86_ascend_training
  119. @pytest.mark.env_onecard
  120. def test_tdt_produce_beyond_consume():
  121. context.set_context(mode=context.GRAPH_MODE)
  122. batch_size = 64
  123. repeat_num = 1
  124. num_rows = 6400
  125. beyond_step_num = 10
  126. ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)
  127. iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
  128. logger.info("out_iter_num:%s", iter_num)
  129. assert iter_num == 10