You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

select_hard_multi_model.py 1.8 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import pandas as pd
  2. import os
  3. import shutil
  4. source_path = '/home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/unlabel_11_12/'
  5. dist_path_01 = '/home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/smd12_11_12_hard_score_03/train/'
  6. infer_data=pd.read_csv('./test_unlabel_11_12.csv')
  7. print(infer_data.shape)
  8. infer_data_cascade=pd.read_csv('./test_cascade_unlabel_11_12.csv')
  9. print(infer_data_cascade.shape)
  10. '''infer_data.info()
  11. infer_data.describe()
  12. infer_data.head()
  13. print(infer_data['score'])
  14. print(infer_data['Image_Name'])'''
  15. #infer_data = infer_data.sort_values('score',ascending=False)
  16. atss_score = {}
  17. for index, row in infer_data.iterrows():
  18. atss_score[row['Image_Name']] = row['score']
  19. cascade_score = {}
  20. for index, row in infer_data_cascade.iterrows():
  21. cascade_score[row['Image_Name']] = row['score']
  22. hard_score = {}
  23. for image_name in atss_score.keys():
  24. hard_score[image_name] = abs(atss_score[image_name] - cascade_score[image_name])
  25. #print(atss_score[image_name], cascade_score[image_name], hard_score[image_name])
  26. hard_score = sorted(hard_score.items(), key=lambda x: x[1], reverse=True)
  27. #print(hard_score)
  28. select_01 = []
  29. count = 0
  30. for k, v in hard_score:
  31. if count<2750:
  32. select_01.append(k)
  33. #print(k, v)
  34. count += 1
  35. print(len(select_01))
  36. count_img = 0
  37. count_label = 0
  38. for file in select_01:
  39. shutil.copy(source_path+'images/'+file, dist_path_01+'images/'+file)
  40. count_img += 1
  41. if os.path.exists(source_path+'labels/'+file.replace(".jpg",".txt")):
  42. shutil.copy(source_path+'labels/'+file.replace(".jpg",".txt"), dist_path_01+'labels/'+file.replace(".jpg",".txt"))
  43. count_label += 1
  44. print(count_img, count_label)
  45. '''print(len(infer_data['feature'][0]))
  46. feat = infer_data['feature'][0].split(",")
  47. print(len(feat))
  48. print(feat[0])'''

No Description

Contributors (3)