| @@ -77,8 +77,6 @@ class GraphKernel(BaseEstimator): #, ABC): | |||||
| # Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used; | # Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used; | ||||
| self.clear_attributes() | self.clear_attributes() | ||||
| # X = check_array(X, accept_sparse=True) | |||||
| # Validate parameters for the transformer. | # Validate parameters for the transformer. | ||||
| self.validate_parameters() | self.validate_parameters() | ||||
| @@ -386,35 +384,58 @@ class GraphKernel(BaseEstimator): #, ABC): | |||||
| self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count()) | self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count()) | ||||
| self.normalize = kwargs.get('normalize', True) | self.normalize = kwargs.get('normalize', True) | ||||
| self.verbose = kwargs.get('verbose', 2) | self.verbose = kwargs.get('verbose', 2) | ||||
| self.copy_graphs = kwargs.get('copy_graphs', True) | |||||
| self.save_unnormed = kwargs.get('save_unnormed', True) | |||||
| self.validate_parameters() | self.validate_parameters() | ||||
| # If the inputs is a list of graphs. | |||||
| if len(graphs) == 1: | if len(graphs) == 1: | ||||
| if not isinstance(graphs[0], list): | if not isinstance(graphs[0], list): | ||||
| raise Exception('Cannot detect graphs.') | raise Exception('Cannot detect graphs.') | ||||
| elif len(graphs[0]) == 0: | elif len(graphs[0]) == 0: | ||||
| raise Exception('The graph list given is empty. No computation was performed.') | raise Exception('The graph list given is empty. No computation was performed.') | ||||
| else: | else: | ||||
| self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow. | |||||
| if self.copy_graphs: | |||||
| self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow. | |||||
| else: | |||||
| self._graphs = graphs | |||||
| self._gram_matrix = self._compute_gram_matrix() | self._gram_matrix = self._compute_gram_matrix() | ||||
| self._gram_matrix_unnorm = np.copy(self._gram_matrix) | |||||
| if self.save_unnormed: | |||||
| self._gram_matrix_unnorm = np.copy(self._gram_matrix) | |||||
| if self.normalize: | if self.normalize: | ||||
| self._gram_matrix = normalize_gram_matrix(self._gram_matrix) | self._gram_matrix = normalize_gram_matrix(self._gram_matrix) | ||||
| return self._gram_matrix, self._run_time | return self._gram_matrix, self._run_time | ||||
| elif len(graphs) == 2: | elif len(graphs) == 2: | ||||
| # If the inputs are two graphs. | |||||
| if self.is_graph(graphs[0]) and self.is_graph(graphs[1]): | if self.is_graph(graphs[0]) and self.is_graph(graphs[1]): | ||||
| kernel = self._compute_single_kernel(graphs[0].copy(), graphs[1].copy()) | |||||
| if self.copy_graphs: | |||||
| G0, G1 = graphs[0].copy(), graphs[1].copy() | |||||
| else: | |||||
| G0, G1 = graphs[0], graphs[1] | |||||
| kernel = self._compute_single_kernel(G0, G1) | |||||
| return kernel, self._run_time | return kernel, self._run_time | ||||
| # If the inputs are a graph and a list of graphs. | |||||
| elif self.is_graph(graphs[0]) and isinstance(graphs[1], list): | elif self.is_graph(graphs[0]) and isinstance(graphs[1], list): | ||||
| g1 = graphs[0].copy() | |||||
| g_list = [g.copy() for g in graphs[1]] | |||||
| kernel_list = self._compute_kernel_list(g1, g_list) | |||||
| if self.copy_graphs: | |||||
| g1 = graphs[0].copy() | |||||
| g_list = [g.copy() for g in graphs[1]] | |||||
| kernel_list = self._compute_kernel_list(g1, g_list) | |||||
| else: | |||||
| kernel_list = self._compute_kernel_list(graphs[0], graphs[1]) | |||||
| return kernel_list, self._run_time | return kernel_list, self._run_time | ||||
| elif isinstance(graphs[0], list) and self.is_graph(graphs[1]): | elif isinstance(graphs[0], list) and self.is_graph(graphs[1]): | ||||
| g1 = graphs[1].copy() | |||||
| g_list = [g.copy() for g in graphs[0]] | |||||
| kernel_list = self._compute_kernel_list(g1, g_list) | |||||
| if self.copy_graphs: | |||||
| g1 = graphs[1].copy() | |||||
| g_list = [g.copy() for g in graphs[0]] | |||||
| kernel_list = self._compute_kernel_list(g1, g_list) | |||||
| else: | |||||
| kernel_list = self._compute_kernel_list(graphs[1], graphs[0]) | |||||
| return kernel_list, self._run_time | return kernel_list, self._run_time | ||||
| else: | else: | ||||
| raise Exception('Cannot detect graphs.') | raise Exception('Cannot detect graphs.') | ||||