{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90e7b1d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "from __future__ import print_function, division\n",
    "\n",
    "# import sys\n",
    "# sys.path.append('/home/xujiahong/openI_benchmark/vechicle_reID_VechicleNet/')\n",
    "\n",
    "import time\n",
    "import yaml\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from torchvision import datasets,transforms\n",
    "import os\n",
    "import scipy.io\n",
    "from tqdm import tqdm\n",
    "from data_utils.model_train import ft_net\n",
    "from utils.util import get_stream_logger\n",
    "from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH\n",
    "\n",
    "\n",
    "def fliplr(img):\n",
    "    '''Flip a batch of images horizontally.\n",
    "\n",
    "    Args:\n",
    "        img: 4-D tensor indexed as N x C x H x W.\n",
    "\n",
    "    Returns:\n",
    "        A new tensor whose last (width) dimension is reversed.\n",
    "    '''\n",
    "    # Build the index on the same device as img so index_select also\n",
    "    # works when the batch already lives on a GPU.\n",
    "    inv_idx = torch.arange(img.size(3)-1, -1, -1, device=img.device).long()\n",
    "    img_flip = img.index_select(3, inv_idx)\n",
    "    return img_flip\n",
    "\n",
    "\n",
    "def extract_feature(model, dataloaders, flip):\n",
    "    '''Run the model over a data loader and return L2-normalised features.\n",
    "\n",
    "    Args:\n",
    "        model: torch module applied to each image batch; its output is\n",
    "            normalised along dim 1, so it is assumed to be (batch, dim)\n",
    "            - TODO confirm against ft_net.\n",
    "        dataloaders: iterable yielding (image_batch, target) pairs.\n",
    "        flip: when true, the output for the horizontally flipped batch is\n",
    "            added to the output for the original batch before normalising.\n",
    "\n",
    "    Returns:\n",
    "        A CPU float tensor with the per-batch features concatenated along\n",
    "        dim 0 in loader order (an empty FloatTensor if the loader yields\n",
    "        nothing).\n",
    "    '''\n",
    "    # Follow the device the model lives on instead of calling .cuda()\n",
    "    # unconditionally, so the function also runs on CPU-only machines.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else torch.device('cpu')\n",
    "\n",
    "    chunks = []\n",
    "    for img, _ in tqdm(dataloaders):\n",
    "        ff = model(img.to(device))\n",
    "\n",
    "        if flip:\n",
    "            ff = ff + model(fliplr(img).to(device))\n",
    "\n",
    "        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)\n",
    "        ff = ff.div(fnorm.expand_as(ff))\n",
    "        chunks.append(ff.data.cpu().float())\n",
    "\n",
    "    # Concatenate once at the end: calling torch.cat inside the loop\n",
    "    # re-copies every previously extracted feature on each batch.\n",
    "    if not chunks:\n",
    "        return torch.FloatTensor()\n",
    "    return torch.cat(chunks, 0)\n",
    "\n",
    "\n",
    "def get_id(img_path):\n",
    "    '''Parse camera IDs and identity labels from image file names.\n",
    "\n",
    "    The first five characters of the file name are read as the identity\n",
    "    label and the first two characters after the first 'c' as the camera\n",
    "    ID (the benchmark_person layout, per the original inline notes).\n",
    "\n",
    "    NOTE(review): the original note by xjh gave the example name\n",
    "    0769_c013_00074310_0 (0769 = vehicleID, 013 = cameraID,\n",
    "    00074310 = frameID). That 4-digit-ID / 3-digit-camera layout does not\n",
    "    match the slices used here - confirm which naming scheme the dataset\n",
    "    actually follows.\n",
    "\n",
    "    Args:\n",
    "        img_path: iterable of (path, class_index) pairs, e.g. the imgs\n",
    "            attribute of an ImageFolder.\n",
    "\n",
    "    Returns:\n",
    "        (camera_id, labels): two lists of ints in input order. File names\n",
    "        without a 'c' get 9999999 for both values; labels that start with\n",
    "        '-1' are mapped to -1.\n",
    "    '''\n",
    "    camera_id = []\n",
    "    labels = []\n",
    "    for path, _ in img_path:\n",
    "        filename = os.path.basename(path)\n",
    "        if 'c' not in filename:\n",
    "            # No camera marker in the name (test gallery image).\n",
    "            labels.append(9999999)\n",
    "            camera_id.append(9999999)\n",
    "            continue\n",
    "        label = filename[0:5]\n",
    "        camera = filename.split('c')[1]\n",
    "        labels.append(-1 if label[0:2] == '-1' else int(label))\n",
    "        camera_id.append(int(camera[0:2]))\n",
    "    return camera_id, labels\n",
    "\n",
    "\n",
    "def test(config_file_path:str, logger):\n",
    "    '''Extract query/gallery features with the trained model and save them.\n",
    "\n",
    "    Steps:\n",
    "      1. Read the YAML config at config_file_path.\n",
    "      2. Build ft_net and load OUTPUT_RESULT_DIR/trained_<name>_last.pth.\n",
    "      3. Extract L2-normalised features for the bounding_box_test\n",
    "         (gallery) and query folders under the configured data_dir.\n",
    "      4. Save features, labels and camera IDs to\n",
    "         OUTPUT_RESULT_DIR/trained_<name>_feature.mat, and the\n",
    "         query-gallery distance matrix plus labels/cameras to\n",
    "         OUTPUT_RESULT_DIR + 'test_result.pkl'.\n",
    "\n",
    "    Args:\n",
    "        config_file_path: path of the YAML configuration file.\n",
    "        logger: logger used to report the total forward time.\n",
    "\n",
    "    Returns:\n",
    "        None; results are written to disk.\n",
    "    '''\n",
    "    with open(config_file_path, encoding='utf-8') as f:\n",
    "        opts = yaml.load(f, Loader=yaml.SafeLoader)\n",
    "\n",
    "    data_dir = opts['input']['dataset']['data_dir']\n",
    "    config = opts['input']['config']\n",
    "    name = 'trained_' + config['name']\n",
    "    trained_model_name = name + '_last.pth'\n",
    "    save_path = OUTPUT_RESULT_DIR\n",
    "\n",
    "    nclass = config['nclass']\n",
    "    stride = config['stride']\n",
    "    pool = config['pool']\n",
    "    droprate = config['droprate']\n",
    "    inputsize = config['inputsize']\n",
    "    w = config['w']\n",
    "    h = config['h']\n",
    "    batchsize = config['batchsize']\n",
    "    flip = opts['test']['flip_test']\n",
    "\n",
    "    trained_model_path = os.path.join(save_path, trained_model_name)\n",
    "\n",
    "    ############################## load model ##############################\n",
    "    model = ft_net(class_num=nclass, droprate=droprate, stride=stride,\n",
    "                   init_model=None, pool=pool, return_f=False)\n",
    "\n",
    "    # NOTE: torch.load unpickles the file - only load checkpoints that come\n",
    "    # from a trusted source. map_location='cpu' lets a checkpoint saved on\n",
    "    # a GPU load on a CPU-only machine; the model is moved to the GPU\n",
    "    # further down when one is available.\n",
    "    state_dict = torch.load(trained_model_path, map_location='cpu')\n",
    "    try:\n",
    "        model.load_state_dict(state_dict)\n",
    "    except RuntimeError:\n",
    "        # Key mismatch: presumably the checkpoint was saved from a\n",
    "        # DataParallel wrapper, so load through one and unwrap again.\n",
    "        model = torch.nn.DataParallel(model)\n",
    "        model.load_state_dict(state_dict)\n",
    "        model = model.module\n",
    "    # Drop the classification layer so the model ends with the feature\n",
    "    # extractor (output length 512 according to the original note).\n",
    "    model.classifier.classifier = nn.Sequential()\n",
    "\n",
    "    ############################## load dataset ############################\n",
    "    # Square inputs are sized from inputsize, others from (h, w); both are\n",
    "    # enlarged by 10% before being fed to the network.\n",
    "    if h == w:\n",
    "        resize_to = (round(inputsize*1.1), round(inputsize*1.1))\n",
    "    else:\n",
    "        resize_to = (round(h*1.1), round(w*1.1))\n",
    "    data_transforms = transforms.Compose([\n",
    "        transforms.Resize(resize_to, interpolation=3),  # 3 == Image.BICUBIC\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])\n",
    "\n",
    "    subsets = ['bounding_box_test', 'query']\n",
    "    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms)\n",
    "                      for x in subsets}\n",
    "    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,\n",
    "                                                  shuffle=False, num_workers=8)\n",
    "                   for x in subsets}\n",
    "\n",
    "    ############################## check GPU ###############################\n",
    "    use_gpu = torch.cuda.is_available()\n",
    "\n",
    "    ############################## extract features ########################\n",
    "    # Change to test mode\n",
    "    model = model.eval()\n",
    "    if use_gpu:\n",
    "        model = model.cuda()\n",
    "\n",
    "    gallery_path = image_datasets['bounding_box_test'].imgs\n",
    "    query_path = image_datasets['query'].imgs\n",
    "\n",
    "    gallery_cam, gallery_label = get_id(gallery_path)\n",
    "    query_cam, query_label = get_id(query_path)\n",
    "\n",
    "    gallery_label = np.asarray(gallery_label)\n",
    "    query_label = np.asarray(query_label)\n",
    "    gallery_cam = np.asarray(gallery_cam)\n",
    "    query_cam = np.asarray(query_cam)\n",
    "    print('Gallery Size: %d'%len(gallery_label))\n",
    "    print('Query Size: %d'%len(query_label))\n",
    "\n",
    "    # Extract feature\n",
    "    since = time.time()\n",
    "    with torch.no_grad():\n",
    "        gallery_feature = extract_feature(model, dataloaders['bounding_box_test'], flip)\n",
    "        query_feature = extract_feature(model, dataloaders['query'], flip)\n",
    "    process_time = time.time() - since\n",
    "    logger.info('total forward time: %.2f minutes'%(process_time/60))\n",
    "\n",
    "    # Features are L2-normalised, so 1 - dot product is the cosine distance.\n",
    "    dist = 1-torch.mm(query_feature, torch.transpose(gallery_feature, 0, 1))\n",
    "\n",
    "    # Save to Matlab for check\n",
    "    extracted_feature = {\n",
    "        'gallery_feature': gallery_feature.numpy(),\n",
    "        'gallery_label': gallery_label,\n",
    "        'gallery_cam': gallery_cam,\n",
    "        'query_feature': query_feature.numpy(),\n",
    "        'query_label': query_label,\n",
    "        'query_cam': query_cam,\n",
    "    }\n",
    "    result_name = os.path.join(save_path, name+'_feature.mat')\n",
    "    scipy.io.savemat(result_name, extracted_feature)\n",
    "\n",
    "    return_dict = {\n",
    "        'dist': dist.numpy(),\n",
    "        'feature_example': query_feature[0].numpy(),\n",
    "        'gallery_label': gallery_label,\n",
    "        'gallery_cam': gallery_cam,\n",
    "        'query_label': query_label,\n",
    "        'query_cam': query_cam,\n",
    "    }\n",
    "\n",
    "    # The path is built by plain concatenation, exactly as before, so any\n",
    "    # existing reader of this file keeps finding it in the same place.\n",
    "    with open(OUTPUT_RESULT_DIR+'test_result.pkl', 'wb') as result_file:\n",
    "        pickle.dump(return_dict, result_file, protocol=4)\n",
    "\n",
    "    return\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    logger = get_stream_logger('TEST')\n",
    "    test(CONFIG_PATH, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27b171e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}