
Untitled.ipynb · 11 kB

The notebook contains a single code cell (plus one empty trailing cell). On its last run the cell failed at line 10 of its imports with:

    ModuleNotFoundError: No module named 'torch'

The notebook is attached to a MindSpore kernel (Python 3.7.6, per the metadata below), and PyTorch is not installed in that environment, so nothing past `import torch` ever executed.
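One way to fail fast with a clearer message is to check for the module before running the cell. This is a minimal sketch, not part of the original notebook:

```python
# Hypothetical guard: verify torch is importable in the active kernel
import importlib.util

if importlib.util.find_spec("torch") is None:
    raise RuntimeError(
        "PyTorch is not available in this kernel (kernelspec: MindSpore). "
        "Switch to a kernel with torch installed, or install it first."
    )
```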
  21. "source": [
  22. "# -*- coding: utf-8 -*-\n",
  23. "from __future__ import print_function, division\n",
  24. "\n",
  25. "# import sys\n",
  26. "# sys.path.append('/home/xujiahong/openI_benchmark/vechicle_reID_VechicleNet/')\n",
  27. "\n",
  28. "import time\n",
  29. "import yaml\n",
  30. "import pickle\n",
  31. "import torch\n",
  32. "import torch.nn as nn\n",
  33. "import numpy as np\n",
  34. "from torchvision import datasets,transforms\n",
  35. "import os\n",
  36. "import scipy.io\n",
  37. "from tqdm import tqdm\n",
  38. "from data_utils.model_train import ft_net\n",
  39. "from utils.util import get_stream_logger\n",
  40. "from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH\n",
  41. "\n",
  42. "\n",
  43. "\n",
  44. "def fliplr(img):\n",
  45. " '''flip horizontal'''\n",
  46. " inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W\n",
  47. " img_flip = img.index_select(3,inv_idx)\n",
  48. " return img_flip\n",
  49. "\n",
  50. "def extract_feature(model, dataloaders, flip):\n",
  51. " features = torch.FloatTensor()\n",
  52. " count = 0\n",
  53. " for _, data in enumerate(tqdm(dataloaders),0):\n",
  54. " img, _ = data\n",
  55. " n, c, h, w = img.size()\n",
  56. " count += n\n",
  57. "\n",
  58. " input_img = img.cuda()\n",
  59. " ff = model(input_img)\n",
  60. "\n",
  61. " if flip:\n",
  62. " img = fliplr(img)\n",
  63. " input_img = img.cuda()\n",
  64. " outputs_flip = model(input_img)\n",
  65. " ff += outputs_flip\n",
  66. "\n",
  67. " fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)\n",
  68. " ff = ff.div(fnorm.expand_as(ff))\n",
  69. " #print(ff.shape)\n",
  70. " features = torch.cat((features,ff.data.cpu().float()), 0)\n",
  71. " #features = torch.cat((features,ff.data.float()), 0)\n",
  72. " return features\n",
  73. "\n",
  74. "\n",
  75. "def get_id(img_path):\n",
  76. " '''\n",
  77. " xjh: \n",
  78. " example of the name of the img: 0769_c013_00074310_0\n",
  79. " 0769 is the vehicleID, 013 is the cameraID, 00074310 is the frameID\n",
  80. " '''\n",
  81. " camera_id = []\n",
  82. " labels = []\n",
  83. " for path, _ in img_path:\n",
  84. " #filename = path.split('/')[-1]\n",
  85. " filename = os.path.basename(path) #get the name of images\n",
  86. " # Test Gallery Image\n",
  87. " if not 'c' in filename: \n",
  88. " labels.append(9999999)\n",
  89. " camera_id.append(9999999)\n",
  90. " else:\n",
  91. " #label = filename[0:4]\n",
  92. " label = filename[0:5] #for benchmark_person\n",
  93. " camera = filename.split('c')[1]\n",
  94. " if label[0:2]=='-1':\n",
  95. " labels.append(-1)\n",
  96. " else:\n",
  97. " labels.append(int(label))\n",
  98. " #camera_id.append(int(camera[0:3]))\n",
  99. " camera_id.append(int(camera[0:2]))#for benchmark_person\n",
  100. " #print(camera[0:3])\n",
  101. " return camera_id, labels\n",
  102. "\n",
  103. "\n",
  104. "def test(config_file_path:str, logger):\n",
  105. " #read config files\n",
  106. " with open(config_file_path, encoding='utf-8') as f:\n",
  107. " opts = yaml.load(f, Loader=yaml.SafeLoader)\n",
  108. "\n",
  109. " data_dir = opts['input']['dataset']['data_dir']\n",
  110. " name = \"trained_\" + opts['input']['config']['name']\n",
  111. " trained_model_name = name + \"_last.pth\"\n",
  112. " save_path = OUTPUT_RESULT_DIR\n",
  113. "\n",
  114. " nclass = opts['input']['config']['nclass']\n",
  115. " stride = opts['input']['config']['stride']\n",
  116. " pool = opts['input']['config']['pool']\n",
  117. " droprate = opts['input']['config']['droprate']\n",
  118. " inputsize= opts['input']['config']['inputsize']\n",
  119. " w = opts['input']['config']['w']\n",
  120. " h = opts['input']['config']['h']\n",
  121. " batchsize = opts['input']['config']['batchsize']\n",
  122. " flip = opts['test']['flip_test']\n",
  123. "\n",
  124. " trained_model_path = os.path.join(save_path, trained_model_name)\n",
  125. "\n",
  126. " ##############################load model#################################################\n",
  127. " ###self-train\n",
  128. " model = ft_net(class_num = nclass, droprate = droprate, stride=stride, init_model=None, pool = pool, return_f=False)\n",
  129. " \n",
  130. " try:\n",
  131. " model.load_state_dict(torch.load(trained_model_path))\n",
  132. " except:\n",
  133. " model = torch.nn.DataParallel(model)\n",
  134. " model.load_state_dict(torch.load(trained_model_path))\n",
  135. " model = model.module\n",
  136. " model.classifier.classifier = nn.Sequential() #model ends with feature extractor(output len is 512)\n",
  137. " # print(model)\n",
  138. " \n",
  139. " ##############################load dataset###############################################\n",
  140. " \n",
  141. " #transforms for input image h==w==299, inputsize==256\n",
  142. " if h == w:\n",
  143. " data_transforms = transforms.Compose([\n",
  144. " transforms.Resize( ( round(inputsize*1.1), round(inputsize*1.1)), interpolation=3),\n",
  145. " transforms.ToTensor(),\n",
  146. " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
  147. " ])\n",
  148. " else:\n",
  149. " data_transforms = transforms.Compose( [\n",
  150. " transforms.Resize((round(h*1.1), round(w*1.1)), interpolation=3), #Image.BICUBIC\n",
  151. " transforms.ToTensor(),\n",
  152. " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
  153. " ])\n",
  154. "\n",
  155. " image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['bounding_box_test','query']}\n",
  156. " dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,\n",
  157. " shuffle=False, num_workers=8) for x in ['bounding_box_test','query']}\n",
  158. "\n",
  159. " #############################check GPU###################################################\n",
  160. " use_gpu = torch.cuda.is_available()\n",
  161. "\n",
  162. "\n",
  163. " #############################extract features############################################\n",
  164. " # Change to test mode\n",
  165. " model = model.eval()\n",
  166. " if use_gpu:\n",
  167. " model = model.cuda()\n",
  168. "\n",
  169. " gallery_path = image_datasets['bounding_box_test'].imgs\n",
  170. " query_path = image_datasets['query'].imgs\n",
  171. "\n",
  172. " gallery_cam,gallery_label = get_id(gallery_path)\n",
  173. " query_cam,query_label = get_id(query_path)\n",
  174. "\n",
  175. "\n",
  176. " gallery_label = np.asarray(gallery_label)\n",
  177. " query_label = np.asarray(query_label)\n",
  178. " gallery_cam = np.asarray(gallery_cam)\n",
  179. " query_cam = np.asarray(query_cam)\n",
  180. " print('Gallery Size: %d'%len(gallery_label))\n",
  181. " print('Query Size: %d'%len(query_label))\n",
  182. " # Extract feature\n",
  183. " since = time.time()\n",
  184. " with torch.no_grad():\n",
  185. " gallery_feature = extract_feature(model, dataloaders['bounding_box_test'], flip)\n",
  186. " query_feature = extract_feature(model, dataloaders['query'], flip)\n",
  187. " process_time = time.time() - since\n",
  188. " logger.info('total forward time: %.2f minutes'%(process_time/60))\n",
  189. " \n",
  190. " dist = 1-torch.mm(query_feature, torch.transpose(gallery_feature, 0, 1))\n",
  191. "\n",
  192. " # Save to Matlab for check\n",
  193. " extracted_feature = {'gallery_feature': gallery_feature.numpy(), 'gallery_label':gallery_label, 'gallery_cam':gallery_cam, \\\n",
  194. " 'query_feature': query_feature.numpy(), 'query_label':query_label, 'query_cam':query_cam}\n",
  195. "\n",
  196. " result_name = os.path.join(save_path, name+'_feature.mat')\n",
  197. " scipy.io.savemat(result_name, extracted_feature) \n",
  198. "\n",
  199. " return_dict = {}\n",
  200. "\n",
  201. " return_dict['dist'] = dist.numpy()\n",
  202. " return_dict['feature_example'] = query_feature[0].numpy()\n",
  203. " return_dict['gallery_label'] = gallery_label\n",
  204. " return_dict['gallery_cam'] = gallery_cam\n",
  205. " return_dict['query_label'] = query_label\n",
  206. " return_dict['query_cam'] = query_cam\n",
  207. "\n",
  208. " pickle.dump(return_dict, open(OUTPUT_RESULT_DIR+'test_result.pkl', 'wb'), protocol=4)\n",
  209. "\n",
  210. " return \n",
  211. "\n",
  212. " # eval_result = evaluator(result, logger)\n",
  213. " # full_table = display_eval_result(dict = eval_result)\n",
  214. " # logger.info(full_table)\n",
  215. "\n",
  216. "if __name__==\"__main__\":\n",
  217. " logger = get_stream_logger('TEST')\n",
  218. " test(CONFIG_PATH, logger)"
```
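For reference, the config keys the script reads (`opts['input']['dataset']`, `opts['input']['config']`, `opts['test']`) imply a YAML layout like the following. The values shown are hypothetical placeholders, not from the repository:

```yaml
input:
  dataset:
    data_dir: /path/to/dataset   # must contain bounding_box_test/ and query/
  config:
    name: resnet50_reid          # hypothetical experiment name
    nclass: 576                  # number of identities in training
    stride: 1
    pool: avg
    droprate: 0.5
    inputsize: 256
    w: 256
    h: 256
    batchsize: 128
test:
  flip_test: true
```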
The notebook ends with an empty code cell; its metadata records the MindSpore kernelspec, Python 3.7.6, and nbformat 4.5 — which is why `import torch` failed on this kernel.
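The commented-out `evaluator` call suggests the distance matrix saved in `test_result.pkl` feeds a standard re-ID evaluation downstream. A minimal sketch of rank-1 accuracy computed from that pickle, assuming the file layout written above (this is not the project's own evaluator):

```python
import pickle
import numpy as np

# Path assumed for illustration; the script writes into OUTPUT_RESULT_DIR
with open('test_result.pkl', 'rb') as f:
    result = pickle.load(f)

dist = result['dist']  # shape (num_query, num_gallery), cosine distances
rank1, valid = 0, 0
for i in range(dist.shape[0]):
    order = np.argsort(dist[i])  # gallery indices, nearest first
    g_label = result['gallery_label'][order]
    g_cam = result['gallery_cam'][order]
    # Standard re-ID protocol: drop junk (-1) labels and
    # same-identity/same-camera matches before ranking
    keep = (g_label != -1) & ~((g_label == result['query_label'][i]) &
                               (g_cam == result['query_cam'][i]))
    good = g_label[keep]
    if good.size == 0:
        continue
    valid += 1
    rank1 += int(good[0] == result['query_label'][i])

print('Rank-1: %.4f' % (rank1 / valid))
```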