In [1]:
# -*- coding: utf-8 -*-
from __future__ import print_function, division

# import sys
# sys.path.append('/home/xujiahong/openI_benchmark/vechicle_reID_VechicleNet/')

import time
import yaml
import pickle
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets,transforms
import os
import scipy.io
from tqdm import tqdm
from data_utils.model_train import ft_net
from utils.util import get_stream_logger
from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH



def fliplr(img):
    '''Flip a batch of images horizontally.

    Args:
        img: 4-D tensor laid out as N x C x H x W; dimension 3 (width) is
            reversed.

    Returns:
        A new tensor with the same shape, dtype and device as ``img`` whose
        width axis is reversed.
    '''
    # torch.flip returns a reversed copy on whatever device ``img`` lives on.
    # The previous arange/index_select version always built its index tensor
    # on the CPU, so it failed for CUDA inputs.
    return torch.flip(img, dims=[3])

def extract_feature(model, dataloaders, flip, device=None):
    '''Run ``model`` over every batch and return L2-normalised features.

    Args:
        model: network mapping an image batch to a feature tensor; the
            normalisation below assumes an (N, D) output.
        dataloaders: iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        flip: if true, the output for the horizontally flipped batch is added
            to the output for the original batch before normalisation.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter. If the model exposes no parameters, it
            falls back to CUDA, which matches the previous unconditional
            ``.cuda()`` behaviour.

    Returns:
        A CPU float32 tensor with the normalised features of all batches,
        concatenated along dim 0 in iteration order. It is an empty
        ``FloatTensor`` when the loader yields nothing.
    '''
    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda')

    chunks = []
    for img, _ in tqdm(dataloaders):
        ff = model(img.to(device))
        if flip:
            # The flip is applied to the batch as yielded by the loader,
            # then moved to the model's device like the original batch.
            ff += model(fliplr(img).to(device))

        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
        ff = ff.div(fnorm.expand_as(ff))
        chunks.append(ff.detach().cpu().float())

    if not chunks:
        return torch.FloatTensor()
    # Concatenate once at the end. Growing the result with torch.cat on
    # every batch re-copies everything gathered so far (quadratic).
    return torch.cat(chunks, 0)


def get_id(img_path):
    '''Parse camera ids and identity labels out of image file names.

    Only the base name of each path is inspected:
      * a name containing no 'c' gets 9999999 for both label and camera;
      * otherwise the label is ``int`` of the first five characters, or -1
        when those characters start with "-1";
      * the camera is ``int`` of the two characters that follow the first
        'c' in the name.
    The original author noted the vehicle naming scheme
    ``0769_c013_00074310_0`` (vehicleID 0769, cameraID 013, frameID
    00074310). The 5-char / 2-digit slicing was changed for a
    "benchmark_person" naming scheme; confirm it matches the dataset in use.

    Args:
        img_path: iterable of ``(path, class_index)`` pairs, e.g.
            ``ImageFolder.imgs``; the class index is ignored.

    Returns:
        ``(camera_ids, labels)`` as two parallel lists of ints.
    '''
    cams = []
    ids = []
    for full_path, _unused in img_path:
        base = os.path.basename(full_path)
        if 'c' not in base:
            # Names without a camera marker get the sentinel for both fields.
            ids.append(9999999)
            cams.append(9999999)
            continue
        prefix = base[:5]
        after_c = base.split('c')[1]
        ids.append(-1 if prefix[:2] == '-1' else int(prefix))
        cams.append(int(after_c[:2]))
    return cams, ids


def _load_trained_model(trained_model_path, nclass, droprate, stride, pool):
    '''Build ``ft_net``, load the checkpoint, and strip the final classifier.'''
    model = ft_net(class_num=nclass, droprate=droprate, stride=stride,
                   init_model=None, pool=pool, return_f=False)
    # Read the checkpoint once, onto the CPU, so weights saved from a GPU run
    # still load on a CPU-only host. The model is moved to the GPU later if
    # one is available.
    state_dict = torch.load(trained_model_path, map_location='cpu')
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on key mismatches. Presumably the
        # checkpoint was saved from a DataParallel wrapper (keys prefixed with
        # 'module.'), so load through a wrapper and unwrap it again.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(state_dict)
        model = model.module
    # Replace the last classifier stage with an empty Sequential so the model
    # returns features instead of class scores (the original note says the
    # output length is 512; confirm against ft_net).
    model.classifier.classifier = nn.Sequential()
    return model


def _build_test_transform(h, w, inputsize):
    '''Return the test-time transform: resize to 1.1x, ToTensor, normalise.

    The resize target is ``inputsize`` on both sides when ``h == w``;
    otherwise it is ``h`` x ``w``.
    '''
    if h == w:
        size = (round(inputsize * 1.1), round(inputsize * 1.1))
    else:
        size = (round(h * 1.1), round(w * 1.1))
    return transforms.Compose([
        transforms.Resize(size, interpolation=3),  # 3 == PIL Image.BICUBIC
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def test(config_file_path:str, logger):
    '''Extract gallery/query features with the trained model and save them.

    Steps:
      * read the YAML config at ``config_file_path``;
      * load ``OUTPUT_RESULT_DIR/trained_<name>_last.pth``;
      * embed the images under ``<data_dir>/bounding_box_test`` (gallery) and
        ``<data_dir>/query``;
      * compute ``dist = 1 - query @ gallery.T``, which is cosine distance
        because ``extract_feature`` L2-normalises the features.

    Outputs written:
      * ``<OUTPUT_RESULT_DIR>/trained_<name>_feature.mat`` with the features,
        labels and camera ids;
      * ``OUTPUT_RESULT_DIR + 'test_result.pkl'`` with the distance matrix,
        labels, camera ids and one example query feature.

    Args:
        config_file_path: path to the YAML configuration file.
        logger: logger used for progress messages.
    '''
    with open(config_file_path, encoding='utf-8') as f:
        opts = yaml.load(f, Loader=yaml.SafeLoader)

    cfg = opts['input']['config']
    data_dir = opts['input']['dataset']['data_dir']
    name = "trained_" + cfg['name']
    save_path = OUTPUT_RESULT_DIR
    nclass = cfg['nclass']
    stride = cfg['stride']
    pool = cfg['pool']
    droprate = cfg['droprate']
    inputsize = cfg['inputsize']
    w = cfg['w']
    h = cfg['h']
    batchsize = cfg['batchsize']
    flip = opts['test']['flip_test']

    trained_model_path = os.path.join(save_path, name + "_last.pth")
    model = _load_trained_model(trained_model_path, nclass, droprate, stride, pool)

    data_transforms = _build_test_transform(h, w, inputsize)
    splits = ['bounding_box_test', 'query']
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms)
                      for x in splits}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,
                                                  shuffle=False, num_workers=8)
                   for x in splits}

    model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    gallery_cam, gallery_label = get_id(image_datasets['bounding_box_test'].imgs)
    query_cam, query_label = get_id(image_datasets['query'].imgs)
    gallery_label = np.asarray(gallery_label)
    query_label = np.asarray(query_label)
    gallery_cam = np.asarray(gallery_cam)
    query_cam = np.asarray(query_cam)
    logger.info('Gallery Size: %d', len(gallery_label))
    logger.info('Query Size: %d', len(query_label))

    since = time.time()
    with torch.no_grad():
        gallery_feature = extract_feature(model, dataloaders['bounding_box_test'], flip)
        query_feature = extract_feature(model, dataloaders['query'], flip)
    logger.info('total forward time: %.2f minutes', (time.time() - since) / 60)

    dist = 1 - torch.mm(query_feature, torch.transpose(gallery_feature, 0, 1))

    # Save to a MATLAB file for inspection.
    extracted_feature = {'gallery_feature': gallery_feature.numpy(), 'gallery_label': gallery_label,
                         'gallery_cam': gallery_cam,
                         'query_feature': query_feature.numpy(), 'query_label': query_label,
                         'query_cam': query_cam}
    scipy.io.savemat(os.path.join(save_path, name + '_feature.mat'), extracted_feature)

    return_dict = {
        'dist': dist.numpy(),
        'feature_example': query_feature[0].numpy(),
        'gallery_label': gallery_label,
        'gallery_cam': gallery_cam,
        'query_label': query_label,
        'query_cam': query_cam,
    }
    # NOTE(review): this path is built by string concatenation, unlike the
    # os.path.join paths above. It is kept unchanged because whatever reads
    # this pickle may rely on it; it only matches os.path.join when
    # OUTPUT_RESULT_DIR ends with a separator.
    with open(OUTPUT_RESULT_DIR + 'test_result.pkl', 'wb') as f:
        pickle.dump(return_dict, f, protocol=4)

if __name__ == "__main__":
    # Script entry point: run the test with the project's default config file.
    stream_logger = get_stream_logger('TEST')
    test(config_file_path=CONFIG_PATH, logger=stream_logger)

ModuleNotFoundError: No module named 'torch'