| @@ -0,0 +1,120 @@ | |||
| #!/usr/bin/env python | |||
| import os, sys, math, random | |||
| from collections import defaultdict | |||
# Python 3 has no xrange; alias it to range so the rest of the script can
# use a single name on both major versions.
if sys.version_info >= (3,):
    xrange = range
def exit_with_help(argv):
    """Print the command-line usage text and terminate the process.

    Args:
        argv: Argument list; only ``argv[0]`` is read, and it is substituted
            into the usage line as the program name.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print("""\
Usage: {0} [options] dataset subset_size [output1] [output2]
This script randomly selects a subset of the dataset.
options:
-s method : method of selection (default 0)
0 -- stratified selection (classification only)
1 -- random selection
output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
    # sys.exit() rather than the bare exit(): the latter is injected by the
    # site module for interactive sessions and is absent under ``python -S``.
    sys.exit(1)
def process_options(argv):
    """Parse the command line into the settings used by main().

    Expected layout: ``prog [options] dataset subset_size [output1] [output2]``.
    Leading arguments that start with "-" are consumed as options; only
    ``-s method`` is recognised and any other option is skipped.

    Args:
        argv: Full argument list, program name first.

    Returns:
        A tuple ``(dataset, subset_size, method, subset_file, rest_file)``:
        the dataset path, the subset size as an int, the selection method
        (0 or 1), a writable file for the subset (``sys.stdout`` when output1
        is not given) and a writable file for the remaining lines (``None``
        when output2 is not given).

    Raises:
        SystemExit: Via exit_with_help() when arguments are missing or the
            selection method is not 0 or 1.
        ValueError: If the method or subset_size is not an integer literal.
    """
    argc = len(argv)
    if argc < 3:
        exit_with_help(argv)

    # default method is stratified selection
    method = 0
    subset_file = sys.stdout
    rest_file = None

    i = 1
    # startswith() instead of argv[i][0]: an empty-string argument is then
    # treated as a positional instead of raising IndexError.
    while i < argc and argv[i].startswith("-"):
        if argv[i] == "-s":
            i += 1
            if i >= argc:
                # "-s" was the last argument, so there is no method to read
                exit_with_help(argv)
            method = int(argv[i])
            if method not in [0, 1]:
                print("Unknown selection method {0}".format(method))
                exit_with_help(argv)
        i += 1

    # The options may have used up the arguments that the argc check above
    # counted, so make sure dataset and subset_size are really present.
    if argc - i < 2:
        exit_with_help(argv)

    dataset = argv[i]
    subset_size = int(argv[i + 1])
    if i + 2 < argc:
        subset_file = open(argv[i + 2], 'w')
    if i + 3 < argc:
        rest_file = open(argv[i + 3], 'w')

    return dataset, subset_size, method, subset_file, rest_file
def random_selection(dataset, subset_size):
    """Pick ``subset_size`` distinct line numbers of ``dataset`` at random.

    Args:
        dataset: Path of the text file to sample from.
        subset_size: Number of lines to select.

    Returns:
        A sorted list of zero-based line numbers.

    Raises:
        ValueError: From random.sample() if ``subset_size`` is negative or
            larger than the number of lines in the file.
    """
    # Count the lines inside a with-block so the handle is closed right away
    # instead of being left to the garbage collector.
    with open(dataset, 'r') as data:
        line_count = sum(1 for _ in data)
    return sorted(random.sample(xrange(line_count), subset_size))
def stratified_selection(dataset, subset_size):
    """Choose line numbers so that every label is represented in the subset.

    The label of a line is its first whitespace-separated token.  Each label
    is given ``ceil(label_size * subset_size / total_lines)`` picks, but never
    fewer than one and never more than what is still left of ``subset_size``.

    Args:
        dataset: Path of the text file to sample from.
        subset_size: Total number of lines to select.

    Returns:
        A sorted list of zero-based line numbers; empty when the file has no
        lines.

    Raises:
        SystemExit: With status -1, after writing an explanation to stderr,
            when the quota is used up before every label received a pick.
        IndexError: If a line contains no token at all (e.g. a blank line).
        ValueError: From random.sample() if a label's share is negative or
            exceeds the number of lines carrying that label.
    """
    # Group zero-based line numbers by label; the with-block closes the file
    # as soon as the scan is finished.
    label_linenums = defaultdict(list)
    total = 0
    with open(dataset) as data:
        for linenum, line in enumerate(data):
            label_linenums[line.split(None, 1)[0]].append(linenum)
            total = linenum + 1

    if not total:
        return []

    # Loop-invariant fraction of the dataset that should end up in the subset.
    ratio = float(subset_size) / total
    remaining = subset_size
    ret = []

    # classes with fewer data are sampled first; otherwise
    # some rare classes may not be selected
    for linenums in sorted(label_linenums.values(), key=len):
        label_size = len(linenums)
        # at least one instance per class
        s = int(min(remaining, max(1, math.ceil(label_size * ratio))))
        if s == 0:
            sys.stderr.write('''\
Error: failed to have at least one instance per class
1. You may have regression data.
2. Your classification data is unbalanced or too small.
Please use -s 1.
''')
            sys.exit(-1)
        remaining -= s
        ret += random.sample(linenums, s)
    return sorted(ret)
def main(argv=sys.argv):
    """Select a subset of the dataset's lines and write it out.

    Parses ``argv`` with process_options(), obtains the line numbers to keep
    from stratified_selection() (method 0) or random_selection() (method 1),
    then copies those lines to the subset output and, when a second output
    was requested, every other line to that one.

    Args:
        argv: Command-line argument list, program name first.
    """
    dataset, subset_size, method, subset_file, rest_file = process_options(argv)

    try:
        # uncomment the following line to fix the random seed
        # random.seed(0)
        selected_lines = []
        if method == 0:
            selected_lines = stratified_selection(dataset, subset_size)
        elif method == 1:
            selected_lines = random_selection(dataset, subset_size)

        # select instances based on selected_lines; both selectors return the
        # line numbers sorted, so one forward pass over the file is enough
        pending = iter(selected_lines)
        wanted = next(pending, None)
        with open(dataset, 'r') as data:
            for linenum, line in enumerate(data):
                if linenum == wanted:
                    subset_file.write(line)
                    wanted = next(pending, None)
                elif rest_file is not None:
                    rest_file.write(line)
                elif wanted is None:
                    # nothing left to select and no rest output to feed
                    break
    finally:
        # Never close sys.stdout: code running after main() may still print.
        if subset_file is sys.stdout:
            subset_file.flush()
        else:
            subset_file.close()
        if rest_file is not None:
            rest_file.close()
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main(sys.argv)