You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_mock.py 2.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''Remove after MindData merge to MindSpore '''
  16. import numpy as np
  17. from mindspore import Tensor
  18. class MindData:
  19. """ Stub for MindData """
  20. def __init__(self, size=None, batch_size=None, repeat_count=1,
  21. np_types=None, output_shapes=None, input_indexs=()):
  22. self._size = size
  23. self._batch_size = batch_size
  24. self._repeat_count = repeat_count
  25. self._np_types = np_types
  26. self._output_shapes = output_shapes
  27. self._input_indexs = input_indexs
  28. self._iter_num = 0
  29. def get_dataset_size(self):
  30. return self._size
  31. def get_repeat_count(self):
  32. return self._repeat_count
  33. def get_batch_size(self):
  34. return self._batch_size
  35. def output_types(self):
  36. return self._np_types
  37. def output_shapes(self):
  38. return self._output_shapes
  39. @property
  40. def input_indexs(self):
  41. return self._input_indexs
  42. def device_que(self, send_epoch_end=True):
  43. self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
  44. self.send_epoch_end = send_epoch_end
  45. return self
  46. def create_tuple_iterator(self, num_epochs=-1):
  47. return self.__iter__()
  48. def send(self, num_epochs=-1):
  49. pass
  50. def stop_send(self):
  51. pass
  52. def continue_send(self):
  53. pass
  54. def __len__(self):
  55. return self._size
  56. def __iter__(self):
  57. return self
  58. def __next__(self):
  59. if self._size < self._iter_num:
  60. raise StopIteration
  61. self._iter_num += 1
  62. next_value = []
  63. for shape, typ in zip(self._output_shapes, self._np_types):
  64. next_value.append(Tensor(np.ndarray(shape, typ)))
  65. return tuple(next_value)
  66. def next(self):
  67. return self.__next__()
  68. def reset(self):
  69. self._iter_num = 0