You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tdt_data_ms.py 4.1 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import sys
  16. import numpy as np
  17. import mindspore.context as context
  18. import mindspore.dataset as ds
  19. import mindspore.dataset.vision.c_transforms as vision
  20. import mindspore.nn as nn
  21. from mindspore.common.api import _executor
  22. from mindspore.common.tensor import Tensor
  23. from mindspore.dataset.vision import Inter
  24. from mindspore.ops import operations as P
  25. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  26. data_path = sys.argv[1]
  27. SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)
  28. def test_me_de_train_dataset():
  29. data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
  30. data_set_new = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
  31. columns_list=["image/encoded", "image/class/label"])
  32. resize_height = 224
  33. resize_width = 224
  34. rescale = 1.0 / 255.0
  35. shift = 0.0
  36. # define map operations
  37. decode_op = vision.Decode()
  38. resize_op = vision.Resize((resize_height, resize_width),
  39. Inter.LINEAR) # Bilinear as default
  40. rescale_op = vision.Rescale(rescale, shift)
  41. # apply map operations on images
  42. data_set_new = data_set_new.map(operations=decode_op, input_columns="image/encoded")
  43. data_set_new = data_set_new.map(operations=resize_op, input_columns="image/encoded")
  44. data_set_new = data_set_new.map(operations=rescale_op, input_columns="image/encoded")
  45. hwc2chw_op = vision.HWC2CHW()
  46. data_set_new = data_set_new.map(operations=hwc2chw_op, input_columns="image/encoded")
  47. data_set_new = data_set_new.repeat(1)
  48. # apply batch operations
  49. batch_size_new = 32
  50. data_set_new = data_set_new.batch(batch_size_new, drop_remainder=True)
  51. return data_set_new
  52. def convert_type(shapes, types):
  53. ms_types = []
  54. for np_shape, np_type in zip(shapes, types):
  55. input_np = np.zeros(np_shape, np_type)
  56. tensor = Tensor(input_np)
  57. ms_types.append(tensor.dtype)
  58. return ms_types
  59. if __name__ == '__main__':
  60. data_set = test_me_de_train_dataset()
  61. dataset_size = data_set.get_dataset_size()
  62. batch_size = data_set.get_batch_size()
  63. dataset_shapes = data_set.output_shapes()
  64. np_types = data_set.output_types()
  65. dataset_types = convert_type(dataset_shapes, np_types)
  66. ds1 = data_set.device_que()
  67. get_next = P.GetNext(dataset_types, dataset_shapes, 2, ds1.queue_name)
  68. tadd = P.ReLU()
  69. class dataiter(nn.Cell):
  70. def construct(self):
  71. input_, _ = get_next()
  72. return tadd(input_)
  73. net = dataiter()
  74. net.set_train()
  75. _executor.init_dataset(ds1.queue_name, 39, batch_size,
  76. dataset_types, dataset_shapes, (), 'dataset')
  77. ds1.send()
  78. for data in data_set.create_tuple_iterator(output_numpy=True, num_epochs=1):
  79. output = net()
  80. print(data[0].any())
  81. print(
  82. "****************************************************************************************************")
  83. d = output.asnumpy()
  84. print(d)
  85. print(
  86. "end+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
  87. d.any())
  88. assert (
  89. (data[0] == d).all()), "TDT test execute failed, please check current code commit"
  90. print(
  91. "+++++++++++++++++++++++++++++++++++[INFO] Success+++++++++++++++++++++++++++++++++++++++++++")