You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_mock.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''Remove after MindData merge to MindSpore '''
  16. import numpy as np
  17. from mindspore import Tensor
  18. class MindData:
  19. """ Stub for MindData """
  20. def __init__(self, size=None, batch_size=None, repeat_count=1,
  21. np_types=None, output_shapes=None, input_indexs=()):
  22. self._size = size
  23. self._batch_size = batch_size
  24. self._repeat_count = repeat_count
  25. self._np_types = np_types
  26. self._output_shapes = output_shapes
  27. self._input_indexs = input_indexs
  28. self._iter_num = 0
  29. def get_dataset_size(self):
  30. return self._size
  31. def get_repeat_count(self):
  32. return self._repeat_count
  33. def get_batch_size(self):
  34. return self._batch_size
  35. def output_types(self):
  36. return self._np_types
  37. def output_shapes(self):
  38. return self._output_shapes
  39. @property
  40. def input_indexs(self):
  41. return self._input_indexs
  42. def device_que(self):
  43. self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
  44. return self
  45. def send(self):
  46. pass
  47. def __len__(self):
  48. return self._size
  49. def __iter__(self):
  50. return self
  51. def __next__(self):
  52. if self._size < self._iter_num:
  53. raise StopIteration
  54. self._iter_num += 1
  55. next_value = []
  56. for shape, typ in zip(self._output_shapes, self._np_types):
  57. next_value.append(Tensor(np.ndarray(shape, typ)))
  58. return tuple(next_value)
  59. def next(self):
  60. return self.__next__()
  61. def reset(self):
  62. self._iter_num = 0