You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_general_data.py 21 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. import copy
  2. import numpy as np
  3. import pytest
  4. import torch
  5. from mmdet.core import GeneralData, InstanceData
  6. def _equal(a, b):
  7. if isinstance(a, (torch.Tensor, np.ndarray)):
  8. return (a == b).all()
  9. else:
  10. return a == b
  11. def test_general_data():
  12. # test init
  13. meta_info = dict(
  14. img_size=[256, 256],
  15. path='dadfaff',
  16. scale_factor=np.array([1.5, 1.5]),
  17. img_shape=torch.rand(4))
  18. data = dict(
  19. bboxes=torch.rand(4, 4),
  20. labels=torch.rand(4),
  21. masks=np.random.rand(4, 2, 2))
  22. instance_data = GeneralData(meta_info=meta_info)
  23. assert 'img_size' in instance_data
  24. assert instance_data.img_size == [256, 256]
  25. assert instance_data['img_size'] == [256, 256]
  26. assert 'path' in instance_data
  27. assert instance_data.path == 'dadfaff'
  28. # test nice_repr
  29. repr_instance_data = instance_data.new(data=data)
  30. nice_repr = str(repr_instance_data)
  31. for line in nice_repr.split('\n'):
  32. if 'masks' in line:
  33. assert 'shape' in line
  34. assert '(4, 2, 2)' in line
  35. if 'bboxes' in line:
  36. assert 'shape' in line
  37. assert 'torch.Size([4, 4])' in line
  38. if 'path' in line:
  39. assert 'dadfaff' in line
  40. if 'scale_factor' in line:
  41. assert '[1.5 1.5]' in line
  42. instance_data = GeneralData(
  43. meta_info=meta_info, data=dict(bboxes=torch.rand(5)))
  44. assert 'bboxes' in instance_data
  45. assert len(instance_data.bboxes) == 5
  46. # data should be a dict
  47. with pytest.raises(AssertionError):
  48. GeneralData(data=1)
  49. # test set data
  50. instance_data = GeneralData()
  51. instance_data.set_data(data)
  52. assert 'bboxes' in instance_data
  53. assert len(instance_data.bboxes) == 4
  54. assert 'masks' in instance_data
  55. assert len(instance_data.masks) == 4
  56. # data should be a dict
  57. with pytest.raises(AssertionError):
  58. instance_data.set_data(data=1)
  59. # test set_meta
  60. instance_data = GeneralData()
  61. instance_data.set_meta_info(meta_info)
  62. assert 'img_size' in instance_data
  63. assert instance_data.img_size == [256, 256]
  64. assert instance_data['img_size'] == [256, 256]
  65. assert 'path' in instance_data
  66. assert instance_data.path == 'dadfaff'
  67. # can skip same value when overwrite
  68. instance_data.set_meta_info(meta_info)
  69. # meta should be a dict
  70. with pytest.raises(AssertionError):
  71. instance_data.set_meta_info(meta_info='fjhka')
  72. # attribute in `_meta_info_field` is immutable once initialized
  73. instance_data.set_meta_info(meta_info)
  74. # meta should be immutable
  75. with pytest.raises(KeyError):
  76. instance_data.set_meta_info(dict(img_size=[254, 251]))
  77. with pytest.raises(KeyError):
  78. duplicate_meta_info = copy.deepcopy(meta_info)
  79. duplicate_meta_info['path'] = 'dada'
  80. instance_data.set_meta_info(duplicate_meta_info)
  81. with pytest.raises(KeyError):
  82. duplicate_meta_info = copy.deepcopy(meta_info)
  83. duplicate_meta_info['scale_factor'] = np.array([1.5, 1.6])
  84. instance_data.set_meta_info(duplicate_meta_info)
  85. # test new_instance_data
  86. instance_data = GeneralData(meta_info)
  87. new_instance_data = instance_data.new()
  88. for k, v in instance_data.meta_info_items():
  89. assert k in new_instance_data
  90. _equal(v, new_instance_data[k])
  91. instance_data = GeneralData(meta_info, data=data)
  92. temp_meta = copy.deepcopy(meta_info)
  93. temp_data = copy.deepcopy(data)
  94. temp_data['time'] = '12212'
  95. temp_meta['img_norm'] = np.random.random(3)
  96. new_instance_data = instance_data.new(meta_info=temp_meta, data=temp_data)
  97. for k, v in new_instance_data.meta_info_items():
  98. if k in instance_data:
  99. _equal(v, instance_data[k])
  100. else:
  101. assert _equal(v, temp_meta[k])
  102. assert k == 'img_norm'
  103. for k, v in new_instance_data.items():
  104. if k in instance_data:
  105. _equal(v, instance_data[k])
  106. else:
  107. assert k == 'time'
  108. assert _equal(v, temp_data[k])
  109. # test keys
  110. instance_data = GeneralData(meta_info, data=dict(bboxes=10))
  111. assert 'bboxes' in instance_data.keys()
  112. instance_data.b = 10
  113. assert 'b' in instance_data
  114. # test meta keys
  115. instance_data = GeneralData(meta_info, data=dict(bboxes=10))
  116. assert 'path' in instance_data.meta_info_keys()
  117. assert len(instance_data.meta_info_keys()) == len(meta_info)
  118. instance_data.set_meta_info(dict(workdir='fafaf'))
  119. assert 'workdir' in instance_data
  120. assert len(instance_data.meta_info_keys()) == len(meta_info) + 1
  121. # test values
  122. instance_data = GeneralData(meta_info, data=dict(bboxes=10))
  123. assert 10 in instance_data.values()
  124. assert len(instance_data.values()) == 1
  125. # test meta values
  126. instance_data = GeneralData(meta_info, data=dict(bboxes=10))
  127. # torch 1.3 eq() can not compare str and tensor
  128. from mmdet import digit_version
  129. if digit_version(torch.__version__) >= [1, 4]:
  130. assert 'dadfaff' in instance_data.meta_info_values()
  131. assert len(instance_data.meta_info_values()) == len(meta_info)
  132. # test items
  133. instance_data = GeneralData(data=data)
  134. for k, v in instance_data.items():
  135. assert k in data
  136. assert _equal(v, data[k])
  137. # test meta_info_items
  138. instance_data = GeneralData(meta_info=meta_info)
  139. for k, v in instance_data.meta_info_items():
  140. assert k in meta_info
  141. assert _equal(v, meta_info[k])
  142. # test __setattr__
  143. new_instance_data = GeneralData(data=data)
  144. new_instance_data.mask = torch.rand(3, 4, 5)
  145. new_instance_data.bboxes = torch.rand(2, 4)
  146. assert 'mask' in new_instance_data
  147. assert len(new_instance_data.mask) == 3
  148. assert len(new_instance_data.bboxes) == 2
  149. # test instance_data_field has been updated
  150. assert 'mask' in new_instance_data._data_fields
  151. assert 'bboxes' in new_instance_data._data_fields
  152. for k in data:
  153. assert k in new_instance_data._data_fields
  154. # '_meta_info_field', '_data_fields' is immutable.
  155. with pytest.raises(AttributeError):
  156. new_instance_data._data_fields = None
  157. with pytest.raises(AttributeError):
  158. new_instance_data._meta_info_fields = None
  159. with pytest.raises(AttributeError):
  160. del new_instance_data._data_fields
  161. with pytest.raises(AttributeError):
  162. del new_instance_data._meta_info_fields
  163. # key in _meta_info_field is immutable
  164. new_instance_data.set_meta_info(meta_info)
  165. with pytest.raises(KeyError):
  166. del new_instance_data.img_size
  167. with pytest.raises(KeyError):
  168. del new_instance_data.scale_factor
  169. for k in new_instance_data.meta_info_keys():
  170. with pytest.raises(AttributeError):
  171. new_instance_data[k] = None
  172. # test __delattr__
  173. # test key can be removed in instance_data_field
  174. assert 'mask' in new_instance_data._data_fields
  175. assert 'mask' in new_instance_data.keys()
  176. assert 'mask' in new_instance_data
  177. assert hasattr(new_instance_data, 'mask')
  178. del new_instance_data.mask
  179. assert 'mask' not in new_instance_data.keys()
  180. assert 'mask' not in new_instance_data
  181. assert 'mask' not in new_instance_data._data_fields
  182. assert not hasattr(new_instance_data, 'mask')
  183. # tset __delitem__
  184. new_instance_data.mask = torch.rand(1, 2, 3)
  185. assert 'mask' in new_instance_data._data_fields
  186. assert 'mask' in new_instance_data
  187. assert hasattr(new_instance_data, 'mask')
  188. del new_instance_data['mask']
  189. assert 'mask' not in new_instance_data
  190. assert 'mask' not in new_instance_data._data_fields
  191. assert 'mask' not in new_instance_data
  192. assert not hasattr(new_instance_data, 'mask')
  193. # test __setitem__
  194. new_instance_data['mask'] = torch.rand(1, 2, 3)
  195. assert 'mask' in new_instance_data._data_fields
  196. assert 'mask' in new_instance_data.keys()
  197. assert hasattr(new_instance_data, 'mask')
  198. # test data_fields has been updated
  199. assert 'mask' in new_instance_data.keys()
  200. assert 'mask' in new_instance_data._data_fields
  201. # '_meta_info_field', '_data_fields' is immutable.
  202. with pytest.raises(AttributeError):
  203. del new_instance_data['_data_fields']
  204. with pytest.raises(AttributeError):
  205. del new_instance_data['_meta_info_field']
  206. # test __getitem__
  207. new_instance_data.mask is new_instance_data['mask']
  208. # test get
  209. assert new_instance_data.get('mask') is new_instance_data.mask
  210. assert new_instance_data.get('none_attribute', None) is None
  211. assert new_instance_data.get('none_attribute', 1) == 1
  212. # test pop
  213. mask = new_instance_data.mask
  214. assert new_instance_data.pop('mask') is mask
  215. assert new_instance_data.pop('mask', None) is None
  216. assert new_instance_data.pop('mask', 1) == 1
  217. # '_meta_info_field', '_data_fields' is immutable.
  218. with pytest.raises(KeyError):
  219. new_instance_data.pop('_data_fields')
  220. with pytest.raises(KeyError):
  221. new_instance_data.pop('_meta_info_field')
  222. # attribute in `_meta_info_field` is immutable
  223. with pytest.raises(KeyError):
  224. new_instance_data.pop('img_size')
  225. # test pop attribute in instance_data_filed
  226. new_instance_data['mask'] = torch.rand(1, 2, 3)
  227. new_instance_data.pop('mask')
  228. # test data_field has been updated
  229. assert 'mask' not in new_instance_data
  230. assert 'mask' not in new_instance_data._data_fields
  231. assert 'mask' not in new_instance_data
  232. # test_keys
  233. new_instance_data.mask = torch.ones(1, 2, 3)
  234. 'mask' in new_instance_data.keys()
  235. has_flag = False
  236. for key in new_instance_data.keys():
  237. if key == 'mask':
  238. has_flag = True
  239. assert has_flag
  240. # test values
  241. assert len(list(new_instance_data.keys())) == len(
  242. list(new_instance_data.values()))
  243. mask = new_instance_data.mask
  244. has_flag = False
  245. for value in new_instance_data.values():
  246. if value is mask:
  247. has_flag = True
  248. assert has_flag
  249. # test items
  250. assert len(list(new_instance_data.keys())) == len(
  251. list(new_instance_data.items()))
  252. mask = new_instance_data.mask
  253. has_flag = False
  254. for key, value in new_instance_data.items():
  255. if value is mask:
  256. assert key == 'mask'
  257. has_flag = True
  258. assert has_flag
  259. # test device
  260. new_instance_data = GeneralData()
  261. if torch.cuda.is_available():
  262. newnew_instance_data = new_instance_data.new()
  263. devices = ('cpu', 'cuda')
  264. for i in range(10):
  265. device = devices[i % 2]
  266. newnew_instance_data[f'{i}'] = torch.rand(1, 2, 3, device=device)
  267. newnew_instance_data = newnew_instance_data.cpu()
  268. for value in newnew_instance_data.values():
  269. assert not value.is_cuda
  270. newnew_instance_data = new_instance_data.new()
  271. devices = ('cuda', 'cpu')
  272. for i in range(10):
  273. device = devices[i % 2]
  274. newnew_instance_data[f'{i}'] = torch.rand(1, 2, 3, device=device)
  275. newnew_instance_data = newnew_instance_data.cuda()
  276. for value in newnew_instance_data.values():
  277. assert value.is_cuda
  278. # test to
  279. double_instance_data = instance_data.new()
  280. double_instance_data.long = torch.LongTensor(1, 2, 3, 4)
  281. double_instance_data.bool = torch.BoolTensor(1, 2, 3, 4)
  282. double_instance_data = instance_data.to(torch.double)
  283. for k, v in double_instance_data.items():
  284. if isinstance(v, torch.Tensor):
  285. assert v.dtype is torch.double
  286. # test .cpu() .cuda()
  287. if torch.cuda.is_available():
  288. cpu_instance_data = double_instance_data.new()
  289. cpu_instance_data.mask = torch.rand(1)
  290. cuda_tensor = torch.rand(1, 2, 3).cuda()
  291. cuda_instance_data = cpu_instance_data.to(cuda_tensor.device)
  292. for value in cuda_instance_data.values():
  293. assert value.is_cuda
  294. cpu_instance_data = cuda_instance_data.cpu()
  295. for value in cpu_instance_data.values():
  296. assert not value.is_cuda
  297. cuda_instance_data = cpu_instance_data.cuda()
  298. for value in cuda_instance_data.values():
  299. assert value.is_cuda
  300. # test detach
  301. grad_instance_data = double_instance_data.new()
  302. grad_instance_data.mask = torch.rand(2, requires_grad=True)
  303. grad_instance_data.mask_1 = torch.rand(2, requires_grad=True)
  304. detach_instance_data = grad_instance_data.detach()
  305. for value in detach_instance_data.values():
  306. assert not value.requires_grad
  307. # test numpy
  308. tensor_instance_data = double_instance_data.new()
  309. tensor_instance_data.mask = torch.rand(2, requires_grad=True)
  310. tensor_instance_data.mask_1 = torch.rand(2, requires_grad=True)
  311. numpy_instance_data = tensor_instance_data.numpy()
  312. for value in numpy_instance_data.values():
  313. assert isinstance(value, np.ndarray)
  314. if torch.cuda.is_available():
  315. tensor_instance_data = double_instance_data.new()
  316. tensor_instance_data.mask = torch.rand(2)
  317. tensor_instance_data.mask_1 = torch.rand(2)
  318. tensor_instance_data = tensor_instance_data.cuda()
  319. numpy_instance_data = tensor_instance_data.numpy()
  320. for value in numpy_instance_data.values():
  321. assert isinstance(value, np.ndarray)
  322. instance_data['_c'] = 10000
  323. instance_data.get('dad', None) is None
  324. assert hasattr(instance_data, '_c')
  325. del instance_data['_c']
  326. assert not hasattr(instance_data, '_c')
  327. instance_data.a = 1000
  328. instance_data['a'] = 2000
  329. assert instance_data['a'] == 2000
  330. assert instance_data.a == 2000
  331. assert instance_data.get('a') == instance_data['a'] == instance_data.a
  332. instance_data._meta = 1000
  333. assert '_meta' in instance_data.keys()
  334. if torch.cuda.is_available():
  335. instance_data.bbox = torch.ones(2, 3, 4, 5).cuda()
  336. instance_data.score = torch.ones(2, 3, 4, 4)
  337. else:
  338. instance_data.bbox = torch.ones(2, 3, 4, 5)
  339. assert len(instance_data.new().keys()) == 0
  340. with pytest.raises(AttributeError):
  341. instance_data.img_size = 100
  342. for k, v in instance_data.items():
  343. if k == 'bbox':
  344. assert isinstance(v, torch.Tensor)
  345. assert 'a' in instance_data
  346. instance_data.pop('a')
  347. assert 'a' not in instance_data
  348. cpu_instance_data = instance_data.cpu()
  349. for k, v in cpu_instance_data.items():
  350. if isinstance(v, torch.Tensor):
  351. assert not v.is_cuda
  352. assert isinstance(cpu_instance_data.numpy().bbox, np.ndarray)
  353. if torch.cuda.is_available():
  354. cuda_resutls = instance_data.cuda()
  355. for k, v in cuda_resutls.items():
  356. if isinstance(v, torch.Tensor):
  357. assert v.is_cuda
  358. def test_instance_data():
  359. meta_info = dict(
  360. img_size=(256, 256),
  361. path='dadfaff',
  362. scale_factor=np.array([1.5, 1.5, 1, 1]))
  363. data = dict(
  364. bboxes=torch.rand(4, 4),
  365. masks=torch.rand(4, 2, 2),
  366. labels=np.random.rand(4),
  367. size=[(i, i) for i in range(4)])
  368. # test init
  369. instance_data = InstanceData(meta_info)
  370. assert 'path' in instance_data
  371. instance_data = InstanceData(meta_info, data=data)
  372. assert len(instance_data) == 4
  373. instance_data.set_data(data)
  374. assert len(instance_data) == 4
  375. meta_info = copy.deepcopy(meta_info)
  376. meta_info['img_name'] = 'flag'
  377. # test newinstance_data
  378. new_instance_data = instance_data.new(meta_info=meta_info)
  379. for k, v in new_instance_data.meta_info_items():
  380. if k in instance_data:
  381. _equal(v, instance_data[k])
  382. else:
  383. assert _equal(v, meta_info[k])
  384. assert k == 'img_name'
  385. # meta info is immutable
  386. with pytest.raises(KeyError):
  387. meta_info = copy.deepcopy(meta_info)
  388. meta_info['path'] = 'fdasfdsd'
  389. instance_data.new(meta_info=meta_info)
  390. # data fields should have same length
  391. with pytest.raises(AssertionError):
  392. temp_data = copy.deepcopy(data)
  393. temp_data['bboxes'] = torch.rand(5, 4)
  394. instance_data.new(data=temp_data)
  395. temp_data = copy.deepcopy(data)
  396. temp_data['scores'] = torch.rand(4)
  397. new_instance_data = instance_data.new(data=temp_data)
  398. for k, v in new_instance_data.items():
  399. if k in instance_data:
  400. _equal(v, instance_data[k])
  401. else:
  402. assert k == 'scores'
  403. assert _equal(v, temp_data[k])
  404. instance_data = instance_data.new()
  405. # test __setattr__
  406. # '_meta_info_field', '_data_fields' is immutable.
  407. with pytest.raises(AttributeError):
  408. instance_data._data_fields = dict()
  409. with pytest.raises(AttributeError):
  410. instance_data._data_fields = dict()
  411. # all attribute in instance_data_field should be
  412. # (torch.Tensor, np.ndarray, list))
  413. with pytest.raises(AssertionError):
  414. instance_data.a = 1000
  415. # instance_data field should has same length
  416. new_instance_data = instance_data.new()
  417. new_instance_data.det_bbox = torch.rand(100, 4)
  418. new_instance_data.det_label = torch.arange(100)
  419. with pytest.raises(AssertionError):
  420. new_instance_data.scores = torch.rand(101, 1)
  421. new_instance_data.none = [None] * 100
  422. with pytest.raises(AssertionError):
  423. new_instance_data.scores = [None] * 101
  424. new_instance_data.numpy_det = np.random.random([100, 1])
  425. with pytest.raises(AssertionError):
  426. new_instance_data.scores = np.random.random([101, 1])
  427. # isinstance(str, slice, int, torch.LongTensor, torch.BoolTensor)
  428. item = torch.Tensor([1, 2, 3, 4])
  429. with pytest.raises(AssertionError):
  430. new_instance_data[item]
  431. len(new_instance_data[item.long()]) == 1
  432. # when input is a bool tensor, The shape of
  433. # the input at index 0 should equal to
  434. # the value length in instance_data_field
  435. with pytest.raises(AssertionError):
  436. new_instance_data[item.bool()]
  437. for i in range(len(new_instance_data)):
  438. assert new_instance_data[i].det_label == i
  439. assert len(new_instance_data[i]) == 1
  440. # assert the index should in 0 ~ len(instance_data) -1
  441. with pytest.raises(IndexError):
  442. new_instance_data[101]
  443. # assert the index should not be an empty tensor
  444. new_new_instance_data = new_instance_data.new()
  445. with pytest.raises(AssertionError):
  446. new_new_instance_data[0]
  447. # test str
  448. with pytest.raises(AssertionError):
  449. instance_data.img_size_dummmy = meta_info['img_size']
  450. # test slice
  451. ten_ressults = new_instance_data[:10]
  452. len(ten_ressults) == 10
  453. for v in ten_ressults.values():
  454. assert len(v) == 10
  455. # test Longtensor
  456. long_tensor = torch.randint(100, (50, ))
  457. long_index_instance_data = new_instance_data[long_tensor]
  458. assert len(long_index_instance_data) == len(long_tensor)
  459. for key, value in long_index_instance_data.items():
  460. if not isinstance(value, list):
  461. assert (long_index_instance_data[key] == new_instance_data[key]
  462. [long_tensor]).all()
  463. else:
  464. len(long_tensor) == len(value)
  465. # test bool tensor
  466. bool_tensor = torch.rand(100) > 0.5
  467. bool_index_instance_data = new_instance_data[bool_tensor]
  468. assert len(bool_index_instance_data) == bool_tensor.sum()
  469. for key, value in bool_index_instance_data.items():
  470. if not isinstance(value, list):
  471. assert (bool_index_instance_data[key] == new_instance_data[key]
  472. [bool_tensor]).all()
  473. else:
  474. assert len(value) == bool_tensor.sum()
  475. num_instance = 1000
  476. instance_data_list = []
  477. # assert len(instance_lists) > 0
  478. with pytest.raises(AssertionError):
  479. instance_data.cat(instance_data_list)
  480. for _ in range(2):
  481. instance_data['bbox'] = torch.rand(num_instance, 4)
  482. instance_data['label'] = torch.rand(num_instance, 1)
  483. instance_data['mask'] = torch.rand(num_instance, 224, 224)
  484. instance_data['instances_infos'] = [1] * num_instance
  485. instance_data['cpu_bbox'] = np.random.random((num_instance, 4))
  486. if torch.cuda.is_available():
  487. instance_data.cuda_tensor = torch.rand(num_instance).cuda()
  488. assert instance_data.cuda_tensor.is_cuda
  489. cuda_instance_data = instance_data.cuda()
  490. assert cuda_instance_data.cuda_tensor.is_cuda
  491. assert len(instance_data[0]) == 1
  492. with pytest.raises(IndexError):
  493. return instance_data[num_instance + 1]
  494. with pytest.raises(AssertionError):
  495. instance_data.centerness = torch.rand(num_instance + 1, 1)
  496. mask_tensor = torch.rand(num_instance) > 0.5
  497. length = mask_tensor.sum()
  498. assert len(instance_data[mask_tensor]) == length
  499. index_tensor = torch.LongTensor([1, 5, 8, 110, 399])
  500. length = len(index_tensor)
  501. assert len(instance_data[index_tensor]) == length
  502. instance_data_list.append(instance_data)
  503. cat_resutls = InstanceData.cat(instance_data_list)
  504. assert len(cat_resutls) == num_instance * 2
  505. instances = InstanceData(data=dict(bboxes=torch.rand(4, 4)))
  506. # cat only single instance
  507. assert len(InstanceData.cat([instances])) == 4

No Description

Contributors (3)