You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

select_hard_feature_gs.py 2.1 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import pandas as pd
  2. import os
  3. import shutil
  4. import numpy as np
  5. from sklearn.covariance import LedoitWolf
  6. from scipy.spatial.distance import mahalanobis
  # Root of the unlabeled pool; the copy step further down reads its
  # 'images/' and 'labels/' subfolders.
  source_path = '/home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/unlabel_11_12/'
  # Root of the training set that receives the selected samples.
  dist_path_01 = '/home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/smd12_11_12_hard_score_04/train/'

  # Inference tables: the unlabeled candidates first, then the baseline set
  # whose features serve as the reference distribution.
  infer_data_unlabel = pd.read_csv('./test_unlabel_11_12.csv')
  print(infer_data_unlabel.shape)
  infer_data_train = pd.read_csv('./test_baseline_06_10.csv')
  print(infer_data_train.shape)

  # Ad-hoc inspection calls, kept disabled:
  # infer_data.info()
  # infer_data.describe()
  # infer_data.head()
  def _parse_feature(text):
      """Convert one serialized feature string into a list of floats.

      The string is split on commas, then the first character of the first
      piece and the last character of the last piece are dropped
      (presumably the enclosing brackets -- confirm against the CSV writer).
      """
      pieces = text.split(",")
      pieces[0] = pieces[0][1:]
      pieces[-1] = pieces[-1][:-1]
      return [float(piece) for piece in pieces]


  # Reference set: one parsed feature vector per row of the baseline table.
  train_feats = np.array(
      [_parse_feature(raw) for raw in infer_data_train['feature']]
  )
  print(train_feats.shape)

  # Gaussian summary of the reference set: per-dimension mean, Ledoit-Wolf
  # covariance estimate, and the inverse of that covariance.
  train_mean = train_feats.mean(axis=0)
  train_cov = LedoitWolf().fit(train_feats).covariance_
  train_cov_inv = np.linalg.inv(train_cov)
  for stat in (train_mean, train_cov, train_cov_inv):
      print(stat.shape)

  # Mahalanobis distance of every unlabeled sample to the reference set,
  # keyed by image name (a repeated name keeps the last distance computed).
  scores = {
      name: mahalanobis(_parse_feature(raw), train_mean, train_cov_inv)
      for name, raw in zip(
          infer_data_unlabel['Image_Name'], infer_data_unlabel['feature']
      )
  }

  # (image name, distance) pairs ordered from largest to smallest distance.
  feat_dist = list(scores.items())
  feat_dist.sort(key=lambda pair: pair[1], reverse=True)
  # print(feat_dist)
  # Keep the image names of the first 2750 entries of the ranked list
  # (all of them when the list is shorter).
  select_01 = [name for name, _ in feat_dist[:2750]]
  print(len(select_01))
  # Copy every selected image, plus its label file when one exists, from the
  # unlabeled pool into the training set, and report how many of each were
  # copied.
  src_images = source_path + 'images/'
  src_labels = source_path + 'labels/'
  dst_images = dist_path_01 + 'images/'
  dst_labels = dist_path_01 + 'labels/'

  # shutil.copy does not create missing parent folders; without these calls
  # the first copy into a fresh output directory raises FileNotFoundError.
  os.makedirs(dst_images, exist_ok=True)
  os.makedirs(dst_labels, exist_ok=True)

  count_img = 0
  count_label = 0
  for file in select_01:
      shutil.copy(src_images + file, dst_images + file)
      count_img += 1

      # The label shares the image's name with '.jpg' swapped for '.txt';
      # images without a label file are copied on their own.
      label_name = file.replace(".jpg", ".txt")
      if os.path.exists(src_labels + label_name):
          shutil.copy(src_labels + label_name, dst_labels + label_name)
          count_label += 1
  print(count_img, count_label)

  # Disabled debugging snippet for inspecting one raw feature string:
  # print(len(infer_data['feature'][0]))
  # feat = infer_data['feature'][0].split(",")
  # print(len(feat))
  # print(feat[0])

No Description

Contributors (3)