@@ -4,7 +4,7 @@
Created on Wed Oct 20 11:48:02 2020

@author: ljia
"""
# This script tests the influence of the ratios between node costs and edge costs on the stability of the GED computation, where the base edit costs are [1, 1, 1, 1, 1, 1].

import os
@@ -13,15 +13,15 @@ import pickle
import logging
from gklearn.ged.util import compute_geds
import time
-from utils import get_dataset
+from utils import get_dataset, set_edit_cost_consts
import sys
-from group_results import group_trials
+from group_results import group_trials, check_group_existence, update_group_marker


def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):

	save_file_suffix = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.trial_' + str(trial)

	# Return if the file exists.
	if os.path.isfile(save_dir + 'ged_matrix' + save_file_suffix + '.pkl'):
		return None, None
@@ -41,8 +41,11 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
		'threads': multiprocessing.cpu_count(),
		'init_option': 'EAGER_WITHOUT_SHUFFLED_COPIES'
		}
-	edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
+	edit_cost_constants = set_edit_cost_consts(ratio,
+			node_labeled=len(dataset.node_labels),
+			edge_labeled=len(dataset.edge_labels),
+			mode='uniform')
#	edit_cost_constants = [item * 0.01 for item in edit_cost_constants]
#	pickle.dump(edit_cost_constants, open(save_dir + "edit_costs" + save_file_suffix + ".pkl", "wb"))
@@ -53,7 +56,7 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
	options['node_attrs'] = dataset.node_attrs
	options['edge_attrs'] = dataset.edge_attrs

	parallel = True # if num_solutions == 1 else False

	"""**5. Compute GED matrix.**"""
	ged_mat = 'error'
	runtime = 0
@@ -67,9 +70,9 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
		logging.basicConfig(filename=LOG_FILENAME, level=logging.DEBUG)
		logging.exception(save_file_suffix)
		print(repr(exp))

	"""**6. Get results.**"""
	with open(save_dir + 'ged_matrix' + save_file_suffix + '.pkl', 'wb') as f:
		pickle.dump(ged_mat, f)
	with open(save_dir + 'runtime' + save_file_suffix + '.pkl', 'wb') as f:
@@ -77,66 +80,76 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
	return ged_mat, runtime


def save_trials_as_group(dataset, ds_name, num_solutions, ratio):

	# Return if the group file exists.
	name_middle = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.'
	name_group = save_dir + 'groups/ged_mats' + name_middle + 'npy'
-	if os.path.isfile(name_group):
+	if check_group_existence(name_group):
		return

	ged_mats = []
	runtimes = []
-	for trial in range(1, 101):
+	num_trials = 100
+	for trial in range(1, num_trials + 1):
		print()
		print('Trial:', trial)
		ged_mat, runtime = xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial)
		ged_mats.append(ged_mat)
		runtimes.append(runtime)

	# Group trials and remove the single files.
	# @todo: if the program is interrupted between the following lines, the group files and the single trial files may become inconsistent.
	name_prefix = 'ged_matrix' + name_middle
-	group_trials(save_dir, name_prefix, True, True, False)
+	group_trials(save_dir, name_prefix, True, True, False, num_trials=num_trials)
	name_prefix = 'runtime' + name_middle
-	group_trials(save_dir, name_prefix, True, True, False)
+	group_trials(save_dir, name_prefix, True, True, False, num_trials=num_trials)
+	update_group_marker(name_group)


def results_for_a_dataset(ds_name):
	"""**1. Get dataset.**"""
	dataset = get_dataset(ds_name)

-	for num_solutions in num_solutions_list:
+	for ratio in ratio_list:
		print()
-		print('# of solutions:', num_solutions)
-		for ratio in ratio_list:
+		print('Ratio:', ratio)
+		for num_solutions in num_solutions_list:
			print()
-			print('Ratio:', ratio)
+			print('# of solutions:', num_solutions)
			save_trials_as_group(dataset, ds_name, num_solutions, ratio)


-def get_param_lists(ds_name):
+def get_param_lists(ds_name, test=False):
+	if test:
+		num_solutions_list = [1, 10, 20, 30, 40, 50]
+		ratio_list = [10]
+		return num_solutions_list, ratio_list
+
	if ds_name == 'AIDS_symb':
		num_solutions_list = [1, 20, 40, 60, 80, 100]
		ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9]
	else:
-		num_solutions_list = [1, 20, 40, 60, 80, 100]
-		ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9]
+		num_solutions_list = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] # [1, 20, 40, 60, 80, 100]
+		ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9, 10][::-1]

	return num_solutions_list, ratio_list


if __name__ == '__main__':
	if len(sys.argv) > 1:
		ds_name_list = sys.argv[1:]
	else:
-		ds_name_list = ['MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']
-	save_dir = 'outputs/edit_costs.num_sols.ratios.IPFP/'
+		ds_name_list = ['Acyclic', 'Alkane_unlabeled', 'MAO_lite', 'Monoterpenoides', 'MUTAG']
+#		ds_name_list = ['Acyclic'] # 'Alkane_unlabeled']
+#		ds_name_list = ['Acyclic', 'MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']

+	save_dir = 'outputs/edit_costs.real_data.num_sols.ratios.IPFP/'
	os.makedirs(save_dir, exist_ok=True)
	os.makedirs(save_dir + 'groups/', exist_ok=True)

	for ds_name in ds_name_list:
		print()
		print('Dataset:', ds_name)

-		num_solutions_list, ratio_list = get_param_lists(ds_name)
+		num_solutions_list, ratio_list = get_param_lists(ds_name, test=False)
		results_for_a_dataset(ds_name)
@@ -5,7 +5,7 @@ Created on Thu Oct 29 17:26:43 2020

@author: ljia

-This script groups results together into a single file for the sake of faster
+This script groups results together into a single file for the sake of faster
searching and loading.
"""
import os
@@ -16,9 +16,55 @@ from tqdm import tqdm
import sys


+def check_group_existence(file_name):
+	path, name = os.path.split(file_name)
+	marker_fn = os.path.join(path, 'group_names_finished.pkl')
+	if os.path.isfile(marker_fn):
+		with open(marker_fn, 'rb') as f:
+			fns = pickle.load(f)
+		if name in fns:
+			return True
+
+	if os.path.isfile(file_name):
+		return True
+
+	return False
+
+
+def update_group_marker(file_name):
+	path, name = os.path.split(file_name)
+	marker_fn = os.path.join(path, 'group_names_finished.pkl')
+	if os.path.isfile(marker_fn):
+		with open(marker_fn, 'rb') as f:
+			fns = pickle.load(f)
+		if name in fns:
+			return
+		else:
+			fns.add(name)
+	else:
+		fns = set({name})
+	with open(marker_fn, 'wb') as f:
+		pickle.dump(fns, f)
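
# A minimal round-trip sketch of the two marker helpers above (the path is
# illustrative): once update_group_marker() has recorded a group file in
# 'group_names_finished.pkl', check_group_existence() reports it as finished
# even if the .npy file itself is no longer on disk.
#
#     name_group = 'outputs/groups/ged_mats.MAO.num_sols_20.ratio_2.00.npy'
#     update_group_marker(name_group)
#     assert check_group_existence(name_group)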
+
+
+def create_group_marker_file(dir_folder, overwrite=True):
+	if not overwrite:
+		return
+
+	fns = set()
+	for file in sorted(os.listdir(dir_folder)):
+		if os.path.isfile(os.path.join(dir_folder, file)):
+			if file.endswith('.npy'):
+				fns.add(file)
+
+	marker_fn = os.path.join(dir_folder, 'group_names_finished.pkl')
+	with open(marker_fn, 'wb') as f:
+		pickle.dump(fns, f)
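
# Intended as a one-off helper (directory as used in the __main__ block below):
# seed the marker file with every group already written to a folder, e.g.
#     create_group_marker_file('outputs/edit_costs.real_data.num_sols.ratios.IPFP/groups/')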


# This function is used by other scripts. Modify it carefully.
-def group_trials(dir_folder, name_prefix, override, clear, backup):
+def group_trials(dir_folder, name_prefix, overwrite, clear, backup, num_trials=100):

	# Get group name.
	label_name = name_prefix.split('.')[0]
	if label_name == 'ged_matrix':
@@ -33,10 +79,10 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
	else:
		name_group = dir_folder + 'groups/' + group_label + name_suffix + 'pkl'

-	if not override and os.path.isfile(name_group):
+	if not overwrite and os.path.isfile(name_group):
		# Check if all trial files exist.
		trials_complete = True
-		for trial in range(1, 101):
+		for trial in range(1, num_trials + 1):
			file_name = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
			if not os.path.isfile(file_name):
				trials_complete = False
@@ -44,7 +90,7 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
	else:
		# Get data.
		data_group = []
-		for trial in range(1, 101):
+		for trial in range(1, num_trials + 1):
			file_name = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
			if os.path.isfile(file_name):
				with open(file_name, 'rb') as f:
@@ -64,7 +110,7 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
			else: # Not all trials are completed.
				return

		# Write groups.
		if label_name == 'ged_matrix':
			data_group = np.array(data_group)
@@ -73,31 +119,31 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
		else:
			with open(name_group, 'wb') as f:
				pickle.dump(data_group, f)

		trials_complete = True

	if trials_complete:
		# Backup.
		if backup:
-			for trial in range(1, 101):
+			for trial in range(1, num_trials + 1):
				src = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
				dst = dir_folder + 'backups/' + name_prefix + 'trial_' + str(trial) + '.pkl'
				copyfile(src, dst)

		# Clear.
		if clear:
-			for trial in range(1, 101):
+			for trial in range(1, num_trials + 1):
				src = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
				os.remove(src)


-def group_all_in_folder(dir_folder, override=False, clear=True, backup=True):
+def group_all_in_folder(dir_folder, overwrite=False, clear=True, backup=True):

	# Create folders.
	os.makedirs(dir_folder + 'groups/', exist_ok=True)
	if backup:
		os.makedirs(dir_folder + 'backups', exist_ok=True)

	# Iterate all files.
	cur_file_prefix = ''
	for file in tqdm(sorted(os.listdir(dir_folder)), desc='Grouping', file=sys.stdout):
@@ -106,20 +152,23 @@ def group_all_in_folder(dir_folder, override=False, clear=True, backup=True):
#			print(name)
#			print(name_prefix)
			if name_prefix != cur_file_prefix:
-				group_trials(dir_folder, name_prefix, override, clear, backup)
+				group_trials(dir_folder, name_prefix, overwrite, clear, backup)
				cur_file_prefix = name_prefix


if __name__ == '__main__':
-	dir_folder = 'outputs/CRIANN/edit_costs.num_sols.ratios.IPFP/'
-	group_all_in_folder(dir_folder)
-
-	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.IPFP/'
-	group_all_in_folder(dir_folder)
-
-	dir_folder = 'outputs/CRIANN/edit_costs.max_num_sols.ratios.bipartite/'
-	group_all_in_folder(dir_folder)
-
-	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.bipartite/'
-	group_all_in_folder(dir_folder)
+#	dir_folder = 'outputs/CRIANN/edit_costs.num_sols.ratios.IPFP/'
+#	group_all_in_folder(dir_folder)

+#	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.IPFP/'
+#	group_all_in_folder(dir_folder)

+#	dir_folder = 'outputs/CRIANN/edit_costs.max_num_sols.ratios.bipartite/'
+#	group_all_in_folder(dir_folder)

+#	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.bipartite/'
+#	group_all_in_folder(dir_folder)
+
+	dir_folder = 'outputs/edit_costs.real_data.num_sols.ratios.IPFP/groups/'
+	create_group_marker_file(dir_folder)
@@ -15,30 +15,30 @@ def get_job_script(arg):
#SBATCH --exclusive
#SBATCH --job-name="st.""" + arg + r""".IPFP"
-#SBATCH --partition=tlong
+#SBATCH --partition=court
#SBATCH --mail-type=ALL
#SBATCH --mail-user=jajupmochi@gmail.com
-#SBATCH --output="outputs/output_edit_costs.nums_sols.ratios.IPFP.""" + arg + """.txt"
-#SBATCH --error="errors/error_edit_costs.nums_sols.ratios.IPFP.""" + arg + """.txt"
+#SBATCH --output="outputs/output_edit_costs.real_data.nums_sols.ratios.IPFP.""" + arg + """.txt"
+#SBATCH --error="errors/error_edit_costs.real_data.nums_sols.ratios.IPFP.""" + arg + """.txt"
#
#SBATCH --ntasks=1
#SBATCH --nodes=1
#SBATCH --cpus-per-task=1
-#SBATCH --time=300:00:00
+#SBATCH --time=48:00:00
#SBATCH --mem-per-cpu=4000

srun hostname
srun cd /home/2019015/ljia02/graphkit-learn/gklearn/experiments/ged/stability
-srun python3 edit_costs.nums_sols.ratios.IPFP.py """ + arg
+srun python3 edit_costs.real_data.nums_sols.ratios.IPFP.py """ + arg
	script = script.strip()
	script = re.sub('\n\t+', '\n', script)
	script = re.sub('\n +', '\n', script)

	return script


if __name__ == '__main__':
-	ds_list = ['MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']
-	for ds_name in [ds_list[i] for i in [0, 3]]:
+	ds_list = ['Acyclic', 'Alkane_unlabeled', 'MAO_lite', 'Monoterpenoides', 'MUTAG']
+	for ds_name in [ds_list[i] for i in [0, 1, 2, 3, 4]]:
		job_script = get_job_script(ds_name)
		command = 'sbatch <<EOF\n' + job_script + '\nEOF'
#		print(command)
@@ -5,26 +5,251 @@ Created on Thu Oct 29 19:17:36 2020

@author: ljia
"""
-from gklearn.utils import Dataset
+import os
+import pickle
+import numpy as np
+from tqdm import tqdm
+import sys
+from gklearn.dataset import Dataset
+from gklearn.experiments import DATASET_ROOT


def get_dataset(ds_name):
	# The node/edge labels that will not be used in the computation.
-	if ds_name == 'MAO':
-		irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
-	elif ds_name == 'Monoterpenoides':
-		irrelevant_labels = {'edge_labels': ['valence']}
-	elif ds_name == 'MUTAG':
-		irrelevant_labels = {'edge_labels': ['label_0']}
-	elif ds_name == 'AIDS_symb':
+#	if ds_name == 'MAO':
+#		irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
+#	if ds_name == 'Monoterpenoides':
+#		irrelevant_labels = {'edge_labels': ['valence']}
+#	elif ds_name == 'MUTAG':
+#		irrelevant_labels = {'edge_labels': ['label_0']}
+	if ds_name == 'AIDS_symb':
		irrelevant_labels = {'node_attrs': ['chem', 'charge', 'x', 'y'], 'edge_labels': ['valence']}
		ds_name = 'AIDS'
	else:
		irrelevant_labels = {}

	# Initialize a Dataset.
-	dataset = Dataset()
-	# Load predefined dataset.
-	dataset.load_predefined_dataset(ds_name)
+	dataset = Dataset(ds_name, root=DATASET_ROOT)
	# Remove irrelevant labels.
	dataset.remove_labels(**irrelevant_labels)

+	print('dataset size:', len(dataset.graphs))
	return dataset


+def set_edit_cost_consts(ratio, node_labeled=True, edge_labeled=True, mode='uniform'):
+	if mode == 'uniform':
+		edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
+		if not node_labeled:
+			edit_cost_constants[2] = 0
+		if not edge_labeled:
+			edit_cost_constants[5] = 0
+
+		return edit_cost_constants
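
# A doctest-style sketch of the helper above (values illustrative). In
# 'uniform' mode the three node costs are scaled by `ratio` while the edge
# costs stay at 1; passing the caller's `len(dataset.node_labels)` also works,
# since a zero length is falsy.
#
#     >>> set_edit_cost_consts(2, node_labeled=True, edge_labeled=True)
#     [2, 2, 2, 1, 1, 1]
#     >>> set_edit_cost_consts(2, node_labeled=False, edge_labeled=True)
#     [2, 2, 0, 1, 1, 1]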


+def nested_keys_exists(element, *keys):
+	'''
+	Check if *keys (nested) exists in `element` (dict).
+	'''
+	if not isinstance(element, dict):
+		raise AttributeError('nested_keys_exists() expects dict as first argument.')
+	if len(keys) == 0:
+		raise AttributeError('nested_keys_exists() expects at least two arguments, one given.')
+
+	_element = element
+	for key in keys:
+		try:
+			_element = _element[key]
+		except KeyError:
+			return False
+	return True
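
# Quick usage sketch:
#
#     >>> nested_keys_exists({'a': {'b': {'c': 1}}}, 'a', 'b', 'c')
#     True
#     >>> nested_keys_exists({'a': {'b': {'c': 1}}}, 'a', 'x')
#     False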


+# Compute the average relative error over the elements of two GED matrices.
+def matrices_ave_relative_error(m1, m2):
+	error = 0
+	base = 0
+	for i in range(m1.shape[0]):
+		for j in range(m1.shape[1]):
+			error += np.abs(m1[i, j] - m2[i, j])
+			base += (np.abs(m1[i, j]) + np.abs(m2[i, j])) / 2
+
+	return error / base
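
# Worked example (illustrative): the absolute differences sum to 2 and the
# element-wise mean magnitudes sum to 11, so the result is 2/11.
#
#     >>> m1 = np.array([[1., 2.], [3., 4.]])
#     >>> m2 = np.array([[1., 2.], [3., 6.]])
#     >>> print(round(matrices_ave_relative_error(m1, m2), 4))
#     0.1818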


+def compute_relative_error(ged_mats):
+	if len(ged_mats) != 0:
+		# Get the smallest "correct" GED matrix.
+		ged_mat_s = np.ones(ged_mats[0].shape) * np.inf
+		for i in range(ged_mats[0].shape[0]):
+			for j in range(ged_mats[0].shape[1]):
+				ged_mat_s[i, j] = np.min([mat[i, j] for mat in ged_mats])
+
+		# Compute average error.
+		errors = []
+		for i, mat in enumerate(ged_mats):
+			err = matrices_ave_relative_error(mat, ged_mat_s)
+#			if not per_correct:
+#				print('matrix # ', str(i))
+#				pass
+			errors.append(err)
+	else:
+		errors = [0]
+
+	return np.mean(errors)
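
# Sketch with two 1x2 matrices (illustrative): their element-wise minimum is
# [[2., 3.]], and each matrix has an average relative error of ~0.1818 against
# it, so the mean is the same value.
#
#     >>> mats = [np.array([[2., 4.]]), np.array([[3., 3.]])]
#     >>> print(round(compute_relative_error(mats), 4))
#     0.1818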


+def parse_group_file_name(fn):
+	splits_all = fn.split('.')
+	key1 = splits_all[1]
+
+	pos2 = splits_all[2].rfind('_')
+#	key2 = splits_all[2][:pos2]
+	val2 = splits_all[2][pos2+1:]
+
+	pos3 = splits_all[3].rfind('_')
+#	key3 = splits_all[3][:pos3]
+	val3 = splits_all[3][pos3+1:] + '.' + splits_all[4]
+
+	return key1, val2, val3
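
# Example of the expected file-name layout (name illustrative):
#
#     >>> parse_group_file_name('ged_mats.MAO.num_sols_20.ratio_0.10.npy')
#     ('MAO', '20', '0.10')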


+def get_all_errors(save_dir, errors):
+	# Loop over each GED matrix file.
+	for file in tqdm(sorted(os.listdir(save_dir)), desc='Getting errors', file=sys.stdout):
+		if os.path.isfile(os.path.join(save_dir, file)) and file.startswith('ged_mats.'):
+			keys = parse_group_file_name(file)
+
+			# Check whether the result is already in errors.
+			if not keys[0] in errors:
+				errors[keys[0]] = {}
+			if not keys[1] in errors[keys[0]]:
+				errors[keys[0]][keys[1]] = {}
+
+			# Compute the error if it does not exist yet.
+			if not keys[2] in errors[keys[0]][keys[1]]:
+				ged_mats = np.load(os.path.join(save_dir, file))
+				errors[keys[0]][keys[1]][keys[2]] = compute_relative_error(ged_mats)
+
+	return errors


+def get_relative_errors(save_dir, overwrite=False):
+	"""Read relative errors from a previously computed and saved file; create
+	the file, compute the errors, or add newly computed errors to the file as
+	necessary.
+
+	Parameters
+	----------
+	save_dir : string
+		Directory containing the grouped GED matrix files.
+	overwrite : bool, optional
+		Whether to recompute all errors from scratch. The default is False.
+
+	Returns
+	-------
+	errors : dict
+		Nested dict of average relative errors.
+	"""
+	fn_err = save_dir + '/relative_errors.pkl'
+	if not overwrite:
+		# If the error file exists, load it and add only what is missing.
+		if os.path.isfile(fn_err):
+			with open(fn_err, 'rb') as f:
+				errors = pickle.load(f)
+			errors = get_all_errors(save_dir, errors)
+		else:
+			errors = get_all_errors(save_dir, {})
+	else:
+		errors = get_all_errors(save_dir, {})
+
+	with open(fn_err, 'wb') as f:
+		pickle.dump(errors, f)
+
+	return errors
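
# Typical call (path as used in group_results.py; the keys follow
# parse_group_file_name and are hypothetical here):
#
#     errors = get_relative_errors('outputs/edit_costs.real_data.num_sols.ratios.IPFP/groups/')
#     err = errors['MAO']['20']['0.10']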


+def interpolate_result(Z, method='linear'):
+	values = Z.copy()
+	for i in range(Z.shape[0]):
+		for j in range(Z.shape[1]):
+			if np.isnan(Z[i, j]):
+				# Get the nearest non-nan values to the left and right.
+				x_neg = np.nan
+				for idx, val in enumerate(Z[i, :][j::-1]):
+					if not np.isnan(val):
+						x_neg = val
+						x_neg_off = idx
+						break
+				x_pos = np.nan
+				for idx, val in enumerate(Z[i, :][j:]):
+					if not np.isnan(val):
+						x_pos = val
+						x_pos_off = idx
+						break
+
+				# Interpolate.
+				if not np.isnan(x_neg) and not np.isnan(x_pos):
+					val_int = (x_pos_off / (x_neg_off + x_pos_off)) * (x_neg - x_pos) + x_pos
+					values[i, j] = val_int
+					break
+
+				# Otherwise, get the nearest non-nan values above and below.
+				y_neg = np.nan
+				for idx, val in enumerate(Z[:, j][i::-1]):
+					if not np.isnan(val):
+						y_neg = val
+						y_neg_off = idx
+						break
+				y_pos = np.nan
+				for idx, val in enumerate(Z[:, j][i:]):
+					if not np.isnan(val):
+						y_pos = val
+						y_pos_off = idx
+						break
+
+				# Interpolate.
+				if not np.isnan(y_neg) and not np.isnan(y_pos):
+					val_int = (y_pos_off / (y_neg_off + y_pos_off)) * (y_neg - y_pos) + y_pos
+					values[i, j] = val_int
+					break
+
+	return values
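
# A one-NaN sketch (illustrative): the gap is filled linearly from its row
# neighbours. Note that the `method` argument is currently unused; only the
# 'linear' behaviour is implemented.
#
#     >>> interpolate_result(np.array([[1., np.nan, 3.]]))
#     array([[1., 2., 3.]])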


+def set_axis_style(ax):
+	ax.set_axisbelow(True)
+	ax.spines['top'].set_visible(False)
+	ax.spines['bottom'].set_visible(False)
+	ax.spines['right'].set_visible(False)
+	ax.spines['left'].set_visible(False)
+	ax.xaxis.set_ticks_position('none')
+	ax.yaxis.set_ticks_position('none')
+	ax.tick_params(labelsize=8, color='w', pad=1, grid_color='w')
+	ax.tick_params(axis='x', pad=-2)
+	ax.tick_params(axis='y', labelrotation=-40, pad=-2)
+#	ax.zaxis._axinfo['juggled'] = (1, 2, 0)
+	ax.set_xlabel(ax.get_xlabel(), fontsize=10, labelpad=-3)
+	ax.set_ylabel(ax.get_ylabel(), fontsize=10, labelpad=-2, rotation=50)
+	ax.set_zlabel(ax.get_zlabel(), fontsize=10, labelpad=-2)
+	ax.set_title(ax.get_title(), pad=30, fontsize=15)
+	return


+if __name__ == '__main__':
+	root_dir = 'outputs/CRIANN/'
+#	for dir_ in sorted(os.listdir(root_dir)):
+#		if os.path.isdir(root_dir):
+#			full_dir = os.path.join(root_dir, dir_)
+#			print('---', full_dir, ':')
+#			save_dir = os.path.join(full_dir, 'groups/')
+#			if os.path.exists(save_dir):
+#				try:
+#					get_relative_errors(save_dir)
+#				except Exception as exp:
+#					print('An exception occurred when running this experiment:')
+#					print(repr(exp))