{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90e7b1d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "from __future__ import print_function, division\n",
    "\n",
    "# import sys\n",
    "# sys.path.append('/home/xujiahong/openI_benchmark/vechicle_reID_VechicleNet/')\n",
    "\n",
    "import time\n",
    "import yaml\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from torchvision import datasets, transforms\n",
    "import os\n",
    "import scipy.io\n",
    "from tqdm import tqdm\n",
    "from data_utils.model_train import ft_net\n",
    "from utils.util import get_stream_logger\n",
    "from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH\n",
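    "\n",
    "# This script extracts features for the query and gallery splits with a\n",
    "# trained ft_net backbone and saves the query-gallery distance matrix\n",
    "# for the downstream evaluator.\n",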
    "\n",
    "\n",
    "def fliplr(img):\n",
    "    '''Flip a batch of images horizontally (input is N x C x H x W).'''\n",
    "    inv_idx = torch.arange(img.size(3) - 1, -1, -1).long()  # reversed width indices\n",
    "    img_flip = img.index_select(3, inv_idx)\n",
    "    return img_flip\n",
    "\n",
    "def extract_feature(model, dataloaders, flip):\n",
    "    '''Run the model over a dataloader and return L2-normalised features.'''\n",
    "    features = torch.FloatTensor()\n",
    "    for img, _ in tqdm(dataloaders):\n",
    "        input_img = img.cuda()\n",
    "        ff = model(input_img)\n",
    "\n",
    "        if flip:\n",
    "            # sum the features of the original and the horizontally flipped image\n",
    "            img = fliplr(img)\n",
    "            input_img = img.cuda()\n",
    "            outputs_flip = model(input_img)\n",
    "            ff += outputs_flip\n",
    "\n",
    "        # L2-normalise each row so that dot products become cosine similarities\n",
    "        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)\n",
    "        ff = ff.div(fnorm.expand_as(ff))\n",
    "        features = torch.cat((features, ff.data.cpu().float()), 0)\n",
    "    return features\n",
    "\n",
    "\n",
    "def get_id(img_path):\n",
    "    '''\n",
    "    Parse camera IDs and labels from image names.\n",
    "    Example name: 0769_c013_00074310_0, where 0769 is the vehicle ID,\n",
    "    013 is the camera ID and 00074310 is the frame ID.\n",
    "    '''\n",
    "    camera_id = []\n",
    "    labels = []\n",
    "    for path, _ in img_path:\n",
    "        filename = os.path.basename(path)  # name of the image file\n",
    "        if 'c' not in filename:\n",
    "            # test gallery images without a camera tag\n",
    "            labels.append(9999999)\n",
    "            camera_id.append(9999999)\n",
    "        else:\n",
    "            label = filename[0:5]  # for benchmark_person; use [0:4] for vehicles\n",
    "            camera = filename.split('c')[1]\n",
    "            if label[0:2] == '-1':\n",
    "                labels.append(-1)\n",
    "            else:\n",
    "                labels.append(int(label))\n",
    "            camera_id.append(int(camera[0:2]))  # for benchmark_person; use [0:3] for vehicles\n",
    "    return camera_id, labels\n",
    "\n",
    "\n",
    "def test(config_file_path: str, logger):\n",
    "    # read config file\n",
    "    with open(config_file_path, encoding='utf-8') as f:\n",
    "        opts = yaml.load(f, Loader=yaml.SafeLoader)\n",
    "\n",
    "    data_dir = opts['input']['dataset']['data_dir']\n",
    "    name = \"trained_\" + opts['input']['config']['name']\n",
    "    trained_model_name = name + \"_last.pth\"\n",
    "    save_path = OUTPUT_RESULT_DIR\n",
    "\n",
    "    nclass = opts['input']['config']['nclass']\n",
    "    stride = opts['input']['config']['stride']\n",
    "    pool = opts['input']['config']['pool']\n",
    "    droprate = opts['input']['config']['droprate']\n",
    "    inputsize = opts['input']['config']['inputsize']\n",
    "    w = opts['input']['config']['w']\n",
    "    h = opts['input']['config']['h']\n",
    "    batchsize = opts['input']['config']['batchsize']\n",
    "    flip = opts['test']['flip_test']\n",
    "\n",
    "    trained_model_path = os.path.join(save_path, trained_model_name)\n",
    "\n",
    "    ############################## load model ###############################\n",
    "    # self-trained backbone\n",
    "    model = ft_net(class_num=nclass, droprate=droprate, stride=stride, init_model=None, pool=pool, return_f=False)\n",
    "\n",
    "    try:\n",
    "        model.load_state_dict(torch.load(trained_model_path))\n",
    "    except RuntimeError:\n",
    "        # the checkpoint was saved from a DataParallel wrapper, so load it\n",
    "        # through the wrapper and then unwrap the underlying module\n",
    "        model = torch.nn.DataParallel(model)\n",
    "        model.load_state_dict(torch.load(trained_model_path))\n",
    "        model = model.module\n",
    "    model.classifier.classifier = nn.Sequential()  # drop the classifier head so the model outputs 512-dim features\n",
    "\n",
    "    ############################## load dataset #############################\n",
    "\n",
    "    # transforms for the input images, e.g. h == w == 299 or inputsize == 256\n",
    "    if h == w:\n",
    "        data_transforms = transforms.Compose([\n",
    "            transforms.Resize((round(inputsize * 1.1), round(inputsize * 1.1)), interpolation=3),  # 3 = Image.BICUBIC\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    else:\n",
    "        data_transforms = transforms.Compose([\n",
    "            transforms.Resize((round(h * 1.1), round(w * 1.1)), interpolation=3),  # 3 = Image.BICUBIC\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "\n",
    "    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms) for x in ['bounding_box_test', 'query']}\n",
    "    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,\n",
    "                                                  shuffle=False, num_workers=8) for x in ['bounding_box_test', 'query']}\n",
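    "    # shuffle=False keeps the extracted features aligned with\n",
    "    # image_datasets[x].imgs, which get_id() relies on below\n",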
    "\n",
    "    ############################# check GPU ##################################\n",
    "    use_gpu = torch.cuda.is_available()\n",
    "\n",
    "    ############################# extract features ###########################\n",
    "    # switch to evaluation mode\n",
    "    model = model.eval()\n",
    "    if use_gpu:\n",
    "        model = model.cuda()\n",
    "\n",
    "    gallery_path = image_datasets['bounding_box_test'].imgs\n",
    "    query_path = image_datasets['query'].imgs\n",
    "\n",
    "    gallery_cam, gallery_label = get_id(gallery_path)\n",
    "    query_cam, query_label = get_id(query_path)\n",
    "\n",
    "    gallery_label = np.asarray(gallery_label)\n",
    "    query_label = np.asarray(query_label)\n",
    "    gallery_cam = np.asarray(gallery_cam)\n",
    "    query_cam = np.asarray(query_cam)\n",
    "    print('Gallery Size: %d' % len(gallery_label))\n",
    "    print('Query Size: %d' % len(query_label))\n",
    "\n",
    "    # extract features without tracking gradients\n",
    "    since = time.time()\n",
    "    with torch.no_grad():\n",
    "        gallery_feature = extract_feature(model, dataloaders['bounding_box_test'], flip)\n",
    "        query_feature = extract_feature(model, dataloaders['query'], flip)\n",
    "    process_time = time.time() - since\n",
    "    logger.info('total forward time: %.2f minutes' % (process_time / 60))\n",
    "\n",
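    "    # features are L2-normalised, so the matrix product gives cosine\n",
    "    # similarities and 1 minus it gives cosine distances\n",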
    "    dist = 1 - torch.mm(query_feature, torch.transpose(gallery_feature, 0, 1))\n",
    "\n",
    "    # save the features to a .mat file so the results can be checked in Matlab\n",
    "    extracted_feature = {'gallery_feature': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_cam': gallery_cam,\n",
    "                         'query_feature': query_feature.numpy(), 'query_label': query_label, 'query_cam': query_cam}\n",
    "\n",
    "    result_name = os.path.join(save_path, name + '_feature.mat')\n",
    "    scipy.io.savemat(result_name, extracted_feature)\n",
    "\n",
    "    return_dict = {}\n",
    "    return_dict['dist'] = dist.numpy()\n",
    "    return_dict['feature_example'] = query_feature[0].numpy()\n",
    "    return_dict['gallery_label'] = gallery_label\n",
    "    return_dict['gallery_cam'] = gallery_cam\n",
    "    return_dict['query_label'] = query_label\n",
    "    return_dict['query_cam'] = query_cam\n",
    "\n",
    "    # protocol=4 supports the large arrays in return_dict\n",
    "    with open(os.path.join(OUTPUT_RESULT_DIR, 'test_result.pkl'), 'wb') as f:\n",
    "        pickle.dump(return_dict, f, protocol=4)\n",
    "\n",
    "    # evaluation runs as a separate step, e.g.:\n",
    "    # eval_result = evaluator(result, logger)\n",
    "    # full_table = display_eval_result(dict=eval_result)\n",
    "    # logger.info(full_table)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    logger = get_stream_logger('TEST')\n",
    "    test(CONFIG_PATH, logger)"
   ]
  },
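  {
   "cell_type": "markdown",
   "id": "8f2c4a10",
   "metadata": {},
   "source": [
    "Below is a minimal sketch of how the saved `test_result.pkl` could be inspected. It assumes the cell above has run in this session (so `OUTPUT_RESULT_DIR` is in scope) and that `test()` has already written the file; the keys mirror `return_dict` above."
   ]
  },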
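  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27b171e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "# Hypothetical sanity check, not part of the original pipeline: reload the\n",
    "# pickled results and confirm the distance matrix is n_query x n_gallery.\n",
    "result_path = os.path.join(OUTPUT_RESULT_DIR, 'test_result.pkl')\n",
    "with open(result_path, 'rb') as f:\n",
    "    result = pickle.load(f)\n",
    "\n",
    "print('dist shape:', result['dist'].shape)\n",
    "print('queries:', len(result['query_label']), 'gallery:', len(result['gallery_label']))"
   ]
  }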
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}