| @@ -0,0 +1,500 @@ | |||
| #!/usr/bin/env python | |||
| __all__ = ['find_parameters'] | |||
| import os, sys, traceback, getpass, time, re | |||
| from threading import Thread | |||
| from subprocess import * | |||
| if sys.version_info[0] < 3: | |||
| from Queue import Queue | |||
| else: | |||
| from queue import Queue | |||
| telnet_workers = [] | |||
| ssh_workers = [] | |||
| nr_local_worker = 1 | |||
| class GridOption: | |||
| def __init__(self, dataset_pathname, options): | |||
| dirname = os.path.dirname(__file__) | |||
| if sys.platform != 'win32': | |||
| self.svmtrain_pathname = os.path.join(dirname, '../svm-train') | |||
| self.gnuplot_pathname = '/usr/bin/gnuplot' | |||
| else: | |||
| # example for windows | |||
| self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe') | |||
| # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe' | |||
| self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe' | |||
| self.fold = 5 | |||
| self.c_begin, self.c_end, self.c_step = -5, 15, 2 | |||
| self.g_begin, self.g_end, self.g_step = 3, -15, -2 | |||
| self.grid_with_c, self.grid_with_g = True, True | |||
| self.dataset_pathname = dataset_pathname | |||
| self.dataset_title = os.path.split(dataset_pathname)[1] | |||
| self.out_pathname = '{0}.out'.format(self.dataset_title) | |||
| self.png_pathname = '{0}.png'.format(self.dataset_title) | |||
| self.pass_through_string = ' ' | |||
| self.resume_pathname = None | |||
| self.parse_options(options) | |||
| def parse_options(self, options): | |||
| if type(options) == str: | |||
| options = options.split() | |||
| i = 0 | |||
| pass_through_options = [] | |||
| while i < len(options): | |||
| if options[i] == '-log2c': | |||
| i = i + 1 | |||
| if options[i] == 'null': | |||
| self.grid_with_c = False | |||
| else: | |||
| self.c_begin, self.c_end, self.c_step = map(float,options[i].split(',')) | |||
| elif options[i] == '-log2g': | |||
| i = i + 1 | |||
| if options[i] == 'null': | |||
| self.grid_with_g = False | |||
| else: | |||
| self.g_begin, self.g_end, self.g_step = map(float,options[i].split(',')) | |||
| elif options[i] == '-v': | |||
| i = i + 1 | |||
| self.fold = options[i] | |||
| elif options[i] in ('-c','-g'): | |||
| raise ValueError('Use -log2c and -log2g.') | |||
| elif options[i] == '-svmtrain': | |||
| i = i + 1 | |||
| self.svmtrain_pathname = options[i] | |||
| elif options[i] == '-gnuplot': | |||
| i = i + 1 | |||
| if options[i] == 'null': | |||
| self.gnuplot_pathname = None | |||
| else: | |||
| self.gnuplot_pathname = options[i] | |||
| elif options[i] == '-out': | |||
| i = i + 1 | |||
| if options[i] == 'null': | |||
| self.out_pathname = None | |||
| else: | |||
| self.out_pathname = options[i] | |||
| elif options[i] == '-png': | |||
| i = i + 1 | |||
| self.png_pathname = options[i] | |||
| elif options[i] == '-resume': | |||
| if i == (len(options)-1) or options[i+1].startswith('-'): | |||
| self.resume_pathname = self.dataset_title + '.out' | |||
| else: | |||
| i = i + 1 | |||
| self.resume_pathname = options[i] | |||
| else: | |||
| pass_through_options.append(options[i]) | |||
| i = i + 1 | |||
| self.pass_through_string = ' '.join(pass_through_options) | |||
| if not os.path.exists(self.svmtrain_pathname): | |||
| raise IOError('svm-train executable not found') | |||
| if not os.path.exists(self.dataset_pathname): | |||
| raise IOError('dataset not found') | |||
| if self.resume_pathname and not os.path.exists(self.resume_pathname): | |||
| raise IOError('file for resumption not found') | |||
| if not self.grid_with_c and not self.grid_with_g: | |||
| raise ValueError('-log2c and -log2g should not be null simultaneously') | |||
| if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname): | |||
| sys.stderr.write('gnuplot executable not found\n') | |||
| self.gnuplot_pathname = None | |||
| def redraw(db,best_param,gnuplot,options,tofile=False): | |||
| if len(db) == 0: return | |||
| begin_level = round(max(x[2] for x in db)) - 3 | |||
| step_size = 0.5 | |||
| best_log2c,best_log2g,best_rate = best_param | |||
| # if newly obtained c, g, or cv values are the same, | |||
| # then stop redrawing the contour. | |||
| if all(x[0] == db[0][0] for x in db): return | |||
| if all(x[1] == db[0][1] for x in db): return | |||
| if all(x[2] == db[0][2] for x in db): return | |||
| if tofile: | |||
| gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n") | |||
| gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode()) | |||
| #gnuplot.write(b"set term postscript color solid\n") | |||
| #gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode()) | |||
| elif sys.platform == 'win32': | |||
| gnuplot.write(b"set term windows\n") | |||
| else: | |||
| gnuplot.write( b"set term x11\n") | |||
| gnuplot.write(b"set xlabel \"log2(C)\"\n") | |||
| gnuplot.write(b"set ylabel \"log2(gamma)\"\n") | |||
| gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode()) | |||
| gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode()) | |||
| gnuplot.write(b"set contour\n") | |||
| gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode()) | |||
| gnuplot.write(b"unset surface\n") | |||
| gnuplot.write(b"unset ztics\n") | |||
| gnuplot.write(b"set view 0,0\n") | |||
| gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode()) | |||
| gnuplot.write(b"unset label\n") | |||
| gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \ | |||
| at screen 0.5,0.85 center\n". \ | |||
| format(best_log2c, best_log2g, best_rate).encode()) | |||
| gnuplot.write("set label \"C = {0} gamma = {1}\"" | |||
| " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode()) | |||
| gnuplot.write(b"set key at screen 0.9,0.9\n") | |||
| gnuplot.write(b"splot \"-\" with lines\n") | |||
| db.sort(key = lambda x:(x[0], -x[1])) | |||
| prevc = db[0][0] | |||
| for line in db: | |||
| if prevc != line[0]: | |||
| gnuplot.write(b"\n") | |||
| prevc = line[0] | |||
| gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode()) | |||
| gnuplot.write(b"e\n") | |||
| gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure | |||
| gnuplot.flush() | |||
| def calculate_jobs(options): | |||
| def range_f(begin,end,step): | |||
| # like range, but works on non-integer too | |||
| seq = [] | |||
| while True: | |||
| if step > 0 and begin > end: break | |||
| if step < 0 and begin < end: break | |||
| seq.append(begin) | |||
| begin = begin + step | |||
| return seq | |||
| def permute_sequence(seq): | |||
| n = len(seq) | |||
| if n <= 1: return seq | |||
| mid = int(n/2) | |||
| left = permute_sequence(seq[:mid]) | |||
| right = permute_sequence(seq[mid+1:]) | |||
| ret = [seq[mid]] | |||
| while left or right: | |||
| if left: ret.append(left.pop(0)) | |||
| if right: ret.append(right.pop(0)) | |||
| return ret | |||
| c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step)) | |||
| g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step)) | |||
| if not options.grid_with_c: | |||
| c_seq = [None] | |||
| if not options.grid_with_g: | |||
| g_seq = [None] | |||
| nr_c = float(len(c_seq)) | |||
| nr_g = float(len(g_seq)) | |||
| i, j = 0, 0 | |||
| jobs = [] | |||
| while i < nr_c or j < nr_g: | |||
| if i/nr_c < j/nr_g: | |||
| # increase C resolution | |||
| line = [] | |||
| for k in range(0,j): | |||
| line.append((c_seq[i],g_seq[k])) | |||
| i = i + 1 | |||
| jobs.append(line) | |||
| else: | |||
| # increase g resolution | |||
| line = [] | |||
| for k in range(0,i): | |||
| line.append((c_seq[k],g_seq[j])) | |||
| j = j + 1 | |||
| jobs.append(line) | |||
| resumed_jobs = {} | |||
| if options.resume_pathname is None: | |||
| return jobs, resumed_jobs | |||
| for line in open(options.resume_pathname, 'r'): | |||
| line = line.strip() | |||
| rst = re.findall(r'rate=([0-9.]+)',line) | |||
| if not rst: | |||
| continue | |||
| rate = float(rst[0]) | |||
| c, g = None, None | |||
| rst = re.findall(r'log2c=([0-9.-]+)',line) | |||
| if rst: | |||
| c = float(rst[0]) | |||
| rst = re.findall(r'log2g=([0-9.-]+)',line) | |||
| if rst: | |||
| g = float(rst[0]) | |||
| resumed_jobs[(c,g)] = rate | |||
| return jobs, resumed_jobs | |||
| class WorkerStopToken: # used to notify the worker to stop or if a worker is dead | |||
| pass | |||
| class Worker(Thread): | |||
| def __init__(self,name,job_queue,result_queue,options): | |||
| Thread.__init__(self) | |||
| self.name = name | |||
| self.job_queue = job_queue | |||
| self.result_queue = result_queue | |||
| self.options = options | |||
| def run(self): | |||
| while True: | |||
| (cexp,gexp) = self.job_queue.get() | |||
| if cexp is WorkerStopToken: | |||
| self.job_queue.put((cexp,gexp)) | |||
| # print('worker {0} stop.'.format(self.name)) | |||
| break | |||
| try: | |||
| c, g = None, None | |||
| if cexp != None: | |||
| c = 2.0**cexp | |||
| if gexp != None: | |||
| g = 2.0**gexp | |||
| rate = self.run_one(c,g) | |||
| if rate is None: raise RuntimeError('get no rate') | |||
| except: | |||
| # we failed, let others do that and we just quit | |||
| traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2]) | |||
| self.job_queue.put((cexp,gexp)) | |||
| sys.stderr.write('worker {0} quit.\n'.format(self.name)) | |||
| break | |||
| else: | |||
| self.result_queue.put((self.name,cexp,gexp,rate)) | |||
| def get_cmd(self,c,g): | |||
| options=self.options | |||
| cmdline = '"' + options.svmtrain_pathname + '"' | |||
| if options.grid_with_c: | |||
| cmdline += ' -c {0} '.format(c) | |||
| if options.grid_with_g: | |||
| cmdline += ' -g {0} '.format(g) | |||
| cmdline += ' -v {0} {1} {2} '.format\ | |||
| (options.fold,options.pass_through_string,options.dataset_pathname) | |||
| return cmdline | |||
| class LocalWorker(Worker): | |||
| def run_one(self,c,g): | |||
| cmdline = self.get_cmd(c,g) | |||
| result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout | |||
| for line in result.readlines(): | |||
| if str(line).find('Cross') != -1: | |||
| return float(line.split()[-1][0:-1]) | |||
| class SSHWorker(Worker): | |||
| def __init__(self,name,job_queue,result_queue,host,options): | |||
| Worker.__init__(self,name,job_queue,result_queue,options) | |||
| self.host = host | |||
| self.cwd = os.getcwd() | |||
| def run_one(self,c,g): | |||
| cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\ | |||
| (self.host,self.cwd,self.get_cmd(c,g)) | |||
| result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout | |||
| for line in result.readlines(): | |||
| if str(line).find('Cross') != -1: | |||
| return float(line.split()[-1][0:-1]) | |||
| class TelnetWorker(Worker): | |||
| def __init__(self,name,job_queue,result_queue,host,username,password,options): | |||
| Worker.__init__(self,name,job_queue,result_queue,options) | |||
| self.host = host | |||
| self.username = username | |||
| self.password = password | |||
| def run(self): | |||
| import telnetlib | |||
| self.tn = tn = telnetlib.Telnet(self.host) | |||
| tn.read_until('login: ') | |||
| tn.write(self.username + '\n') | |||
| tn.read_until('Password: ') | |||
| tn.write(self.password + '\n') | |||
| # XXX: how to know whether login is successful? | |||
| tn.read_until(self.username) | |||
| # | |||
| print('login ok', self.host) | |||
| tn.write('cd '+os.getcwd()+'\n') | |||
| Worker.run(self) | |||
| tn.write('exit\n') | |||
| def run_one(self,c,g): | |||
| cmdline = self.get_cmd(c,g) | |||
| result = self.tn.write(cmdline+'\n') | |||
| (idx,matchm,output) = self.tn.expect(['Cross.*\n']) | |||
| for line in output.split('\n'): | |||
| if str(line).find('Cross') != -1: | |||
| return float(line.split()[-1][0:-1]) | |||
| def find_parameters(dataset_pathname, options=''): | |||
| def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed): | |||
| if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c): | |||
| best_rate,best_c,best_g = rate,c,g | |||
| stdout_str = '[{0}] {1} {2} (best '.format\ | |||
| (worker,' '.join(str(x) for x in [c,g] if x is not None),rate) | |||
| output_str = '' | |||
| if c != None: | |||
| stdout_str += 'c={0}, '.format(2.0**best_c) | |||
| output_str += 'log2c={0} '.format(c) | |||
| if g != None: | |||
| stdout_str += 'g={0}, '.format(2.0**best_g) | |||
| output_str += 'log2g={0} '.format(g) | |||
| stdout_str += 'rate={0})'.format(best_rate) | |||
| print(stdout_str) | |||
| if options.out_pathname and not resumed: | |||
| output_str += 'rate={0}\n'.format(rate) | |||
| result_file.write(output_str) | |||
| result_file.flush() | |||
| return best_c,best_g,best_rate | |||
| options = GridOption(dataset_pathname, options); | |||
| if options.gnuplot_pathname: | |||
| gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin | |||
| else: | |||
| gnuplot = None | |||
| # put jobs in queue | |||
| jobs,resumed_jobs = calculate_jobs(options) | |||
| job_queue = Queue(0) | |||
| result_queue = Queue(0) | |||
| for (c,g) in resumed_jobs: | |||
| result_queue.put(('resumed',c,g,resumed_jobs[(c,g)])) | |||
| for line in jobs: | |||
| for (c,g) in line: | |||
| if (c,g) not in resumed_jobs: | |||
| job_queue.put((c,g)) | |||
| # hack the queue to become a stack -- | |||
| # this is important when some thread | |||
| # failed and re-put a job. It we still | |||
| # use FIFO, the job will be put | |||
| # into the end of the queue, and the graph | |||
| # will only be updated in the end | |||
| job_queue._put = job_queue.queue.appendleft | |||
| # fire telnet workers | |||
| if telnet_workers: | |||
| nr_telnet_worker = len(telnet_workers) | |||
| username = getpass.getuser() | |||
| password = getpass.getpass() | |||
| for host in telnet_workers: | |||
| worker = TelnetWorker(host,job_queue,result_queue, | |||
| host,username,password,options) | |||
| worker.start() | |||
| # fire ssh workers | |||
| if ssh_workers: | |||
| for host in ssh_workers: | |||
| worker = SSHWorker(host,job_queue,result_queue,host,options) | |||
| worker.start() | |||
| # fire local workers | |||
| for i in range(nr_local_worker): | |||
| worker = LocalWorker('local',job_queue,result_queue,options) | |||
| worker.start() | |||
| # gather results | |||
| done_jobs = {} | |||
| if options.out_pathname: | |||
| if options.resume_pathname: | |||
| result_file = open(options.out_pathname, 'a') | |||
| else: | |||
| result_file = open(options.out_pathname, 'w') | |||
| db = [] | |||
| best_rate = -1 | |||
| best_c,best_g = None,None | |||
| for (c,g) in resumed_jobs: | |||
| rate = resumed_jobs[(c,g)] | |||
| best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True) | |||
| for line in jobs: | |||
| for (c,g) in line: | |||
| while (c,g) not in done_jobs: | |||
| (worker,c1,g1,rate1) = result_queue.get() | |||
| done_jobs[(c1,g1)] = rate1 | |||
| if (c1,g1) not in resumed_jobs: | |||
| best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False) | |||
| db.append((c,g,done_jobs[(c,g)])) | |||
| if gnuplot and options.grid_with_c and options.grid_with_g: | |||
| redraw(db,[best_c, best_g, best_rate],gnuplot,options) | |||
| redraw(db,[best_c, best_g, best_rate],gnuplot,options,True) | |||
| if options.out_pathname: | |||
| result_file.close() | |||
| job_queue.put((WorkerStopToken,None)) | |||
| best_param, best_cg = {}, [] | |||
| if best_c != None: | |||
| best_param['c'] = 2.0**best_c | |||
| best_cg += [2.0**best_c] | |||
| if best_g != None: | |||
| best_param['g'] = 2.0**best_g | |||
| best_cg += [2.0**best_g] | |||
| print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate)) | |||
| return best_rate, best_param | |||
| if __name__ == '__main__': | |||
| def exit_with_help(): | |||
| print("""\ | |||
| Usage: grid.py [grid_options] [svm_options] dataset | |||
| grid_options : | |||
| -log2c {begin,end,step | "null"} : set the range of c (default -5,15,2) | |||
| begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end} | |||
| "null" -- do not grid with c | |||
| -log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2) | |||
| begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end} | |||
| "null" -- do not grid with g | |||
| -v n : n-fold cross validation (default 5) | |||
| -svmtrain pathname : set svm executable path and name | |||
| -gnuplot {pathname | "null"} : | |||
| pathname -- set gnuplot executable path and name | |||
| "null" -- do not plot | |||
| -out {pathname | "null"} : (default dataset.out) | |||
| pathname -- set output file path and name | |||
| "null" -- do not output file | |||
| -png pathname : set graphic output file path and name (default dataset.png) | |||
| -resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out) | |||
| This is experimental. Try this option only if some parameters have been checked for the SAME data. | |||
| svm_options : additional options for svm-train""") | |||
| sys.exit(1) | |||
| if len(sys.argv) < 2: | |||
| exit_with_help() | |||
| dataset_pathname = sys.argv[-1] | |||
| options = sys.argv[1:-1] | |||
| try: | |||
| find_parameters(dataset_pathname, options) | |||
| except (IOError,ValueError) as e: | |||
| sys.stderr.write(str(e) + '\n') | |||
| sys.stderr.write('Try "grid.py" for more information.\n') | |||
| sys.exit(1) | |||