| @@ -40,7 +40,10 @@ class _OGBNDatasetUtil(_OGBDatasetUtil): | |||
| edge_feat = torch.tensor(edge_feat) | |||
| edge_index = SparseTensor(row=torch.tensor(edge_index[0]), col=torch.tensor(edge_index[1]), value=edge_feat, sparse_sizes=(num_nodes, num_nodes)) | |||
| _, _, value = edge_index.coo() | |||
| ogbn_data['edge_feat'] = value.cpu().detach().numpy() | |||
| if value is not None: | |||
| ogbn_data['edge_feat'] = value.cpu().detach().numpy() | |||
| else: | |||
| ogbn_data['edge_feat'] = edge_feat | |||
| edge_index = edge_index.to_symmetric() | |||
| row, col, _ = edge_index.coo() | |||
| edge_index = np.array([row.cpu().detach().numpy(), col.cpu().detach().numpy()]) | |||
| @@ -18,7 +18,7 @@ from torch.optim.lr_scheduler import ( | |||
| ReduceLROnPlateau, | |||
| ) | |||
| from dig.sslgraph.method.contrastive.objectives import NCE_loss, JSE_loss | |||
| from .losses import NTXent_loss | |||
| from .utils import get_view_by_name | |||
| from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer | |||
| @@ -49,7 +49,7 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: Union[str, Callable] = "NCE", | |||
| loss: Union[str, Callable] = "NT_Xent", | |||
| f_loss: Union[str, Callable] = "nll_loss", | |||
| views_fn: _typing.Union[ | |||
| _typing.Sequence[_typing.Callable], None | |||
| @@ -59,7 +59,6 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| node_level: bool = False, | |||
| z_dim: _typing.Union[int, None] = None, | |||
| z_node_dim: _typing.Union[int, None] = None, | |||
| neg_by_crpt: bool = False, | |||
| tau: int = 0.5, | |||
| p_optim: Union[torch.optim.Optimizer, str] = "Adam", | |||
| p_lr: float = 0.0001, | |||
| @@ -110,16 +109,13 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| The dimension of graph-level representations | |||
| z_node_dim: `int`, Optional | |||
| The dimension of node-level representations | |||
| neg_by_crpt: `bool`, Optional | |||
| The mode to obtain negative samples | |||
| tau: `int`, Optional | |||
| The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE" | |||
| The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent" | |||
| model_path: `str` or None, Optional | |||
| The directory to restore the saved model. | |||
| If `model_path` = None, the model will not be saved. | |||
| """ | |||
| assert (node_level or graph_level) is True | |||
| assert not (loss == "NCE" and neg_by_crpt) | |||
| assert isinstance(encoder, BaseEncoderMaintainer) or isinstance(encoder, str) or encoder is None | |||
| self.loss = self._get_loss(loss) | |||
| self.node_level = node_level | |||
| @@ -141,7 +137,6 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| self.last_dim = z_dim if graph_level else z_node_dim | |||
| self.num_features = num_features | |||
| self.num_graph_features = num_graph_features | |||
| self.neg_by_crpt = neg_by_crpt | |||
| self.tau = tau | |||
| self.model_path = model_path | |||
| if isinstance(device, str): | |||
| @@ -195,8 +190,8 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| if callable(loss): | |||
| return loss | |||
| elif isinstance(loss, str): | |||
| assert loss in ['JSE', 'NCE'] | |||
| return {'JSE': JSE_loss, 'NCE': NCE_loss}[loss] | |||
| assert loss in ['NT_Xent'] | |||
| return {'NT_Xent': NTXent_loss}[loss] | |||
| else: | |||
| raise NotImplementedError("The argument `loss` should be str or callable which returns a loss tensor") | |||
| @@ -453,7 +448,7 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| for view in views: | |||
| z = self._get_embed(view.to(self.device)) | |||
| zs.append(self.decoder.decoder(z, view.to(self.device))) | |||
| loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau) | |||
| loss = self.loss(zs, tau=self.tau) | |||
| loss.backward() | |||
| optimizer.step() | |||
| if self.p_lr_scheduler_type: | |||
| @@ -474,7 +469,7 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| for view in views: | |||
| z = self._get_embed(view.to(self.device)) | |||
| zs.append(self.decoder.decoder(z, view.to(self.device))) | |||
| loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau) | |||
| loss = self.loss(zs, tau=self.tau) | |||
| epoch_loss += loss.item() | |||
| last_loss = loss.item() | |||
| return epoch_loss, last_loss | |||
| @@ -534,19 +529,7 @@ class BaseContrastiveTrainer(BaseTrainer): | |||
| return self.encoder.encoder.to(self.device) | |||
| def _get_embed(self, view): | |||
| if self.neg_by_crpt: | |||
| view_crpt = self._corrupt_graph(view) | |||
| if self.node_level and self.graph_level: | |||
| z_g, z_n = self.encoder.encoder(view) | |||
| z_g_crpt, z_n_crpt = self.encoder.encoder(view_crpt) | |||
| z = (torch.cat([z_g, z_g_crpt], 0), | |||
| torch.cat([z_n, z_n_crpt], 0)) | |||
| else: | |||
| z = self.encoder.encoder(view) | |||
| z_crpt = self.encoder.encoder(view_crpt) | |||
| z = torch.cat([z, z_crpt], 0) | |||
| else: | |||
| z = self.encoder.encoder(view) | |||
| z = self.encoder.encoder(view) | |||
| return z | |||
| def predict(self, dataset, mask="test"): | |||
| @@ -1,4 +1,3 @@ | |||
| # codes in this file are reproduced from <https://github.com/divelab/DIG> with some changes. | |||
| import os | |||
| import torch | |||
| import logging | |||
| @@ -14,7 +13,6 @@ from typing import Union, Tuple, Sequence, Type, Callable | |||
| from tqdm import trange | |||
| from copy import deepcopy | |||
| from dig.sslgraph.evaluation.eval_graph import k_fold | |||
| from .base import BaseContrastiveTrainer | |||
| @@ -54,7 +52,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer): | |||
| ] = None, | |||
| aug_ratio: Union[float, Sequence[float]] = 0.2, | |||
| z_dim: Union[int, None] = 128, | |||
| neg_by_crpt: bool = False, | |||
| tau: int = 0.5, | |||
| model_path: Union[str, None] = "./models", | |||
| num_workers: int = 0, | |||
| @@ -105,10 +102,8 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer): | |||
| If aug_ratio is set as a list of float, the value of this list and views_fn one to one correspondence. | |||
| z_dim: `int` | |||
| The dimension of graph-level representations | |||
| neg_by_crpt: `bool` | |||
| The mode to obtain negative samples. Only required when `loss` = "JSE" | |||
| tau: `int` | |||
| The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE" | |||
| The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent" | |||
| model_path: `str` or None | |||
| The directory to restore the saved model. | |||
| If `model_path` = None, the model will not be saved. | |||
| @@ -165,9 +160,10 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer): | |||
| feval=feval, | |||
| z_dim=z_dim, | |||
| z_node_dim=None, | |||
| neg_by_crpt=neg_by_crpt, | |||
| tau=tau, | |||
| model_path=model_path | |||
| model_path=model_path, | |||
| *args, | |||
| **kwargs | |||
| ) | |||
| self.views_fn = views_fn | |||
| self.aug_ratio = aug_ratio | |||
| @@ -438,7 +434,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer): | |||
| views_fn=self.views_fn_opt, | |||
| aug_ratio=self.aug_ratio, | |||
| z_dim=self.last_dim, | |||
| neg_by_crpt=self.neg_by_crpt, | |||
| tau=self.tau, | |||
| model_path=self.model_path, | |||
| num_workers=self.num_workers, | |||
| @@ -530,10 +525,8 @@ class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer): | |||
| If aug_ratio is set as a list of float, the value of this list and views_fn one to one correspondence. | |||
| z_dim: `int` | |||
| The dimension of graph-level representations | |||
| neg_by_crpt: `bool` | |||
| The mode to obtain negative samples. Only required when `loss` = "JSE" | |||
| tau: `int` | |||
| The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE" | |||
| The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent" | |||
| model_path: `str` or None | |||
| The directory to restore the saved model. | |||
| If `model_path` = None, the model will not be saved. | |||
| @@ -894,7 +887,6 @@ class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer): | |||
| views_fn=self.views_fn_opt, | |||
| aug_ratio=self.aug_ratio, | |||
| z_dim=self.last_dim, | |||
| neg_by_crpt=self.neg_by_crpt, | |||
| tau=self.tau, | |||
| model_path=self.model_path, | |||
| num_workers=self.num_workers, | |||
| @@ -0,0 +1,17 @@ | |||
| # NTXent_loss from <https://github.com/Shen-Lab/GraphCL/> | |||
| import torch | |||
| import torch.nn as nn | |||
def NTXent_loss(zs, tau=0.5, norm=True):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Parameters
    ----------
    zs : sequence of torch.Tensor
        Embeddings of two views, each of shape (batch_size, dim).
        Row i of ``zs[0]`` and row i of ``zs[1]`` form a positive pair;
        every other row is treated as a negative.
    tau : float, optional
        Temperature parameter.
    norm : bool, optional
        If True, similarities are cosine (pairwise dot products divided
        by the outer product of the row norms); otherwise raw dot products.

    Returns
    -------
    torch.Tensor
        Scalar loss averaged over the batch.
    """
    z1, z2 = zs[0], zs[1]
    n_pairs = z1.size(0)
    # Pairwise similarity matrix between the two views.
    sim = z1 @ z2.t()
    if norm:
        sim = sim / (z1.norm(dim=1).unsqueeze(1) * z2.norm(dim=1).unsqueeze(0))
    sim = torch.exp(sim / tau)
    # Positive similarities lie on the diagonal; negatives are the rest of each row.
    pos = sim.diagonal()
    ratio = pos / (sim.sum(dim=1) - pos)
    return -torch.log(ratio).mean()
| @@ -1,8 +1,8 @@ | |||
| from dig.sslgraph.method.contrastive.views_fn import ( | |||
| NodeAttrMask, | |||
| EdgePerturbation, | |||
| UniformSample, | |||
| RWSample, | |||
| from .views_fn import ( | |||
| DropNode, | |||
| PermuteEdge, | |||
| MaskNode, | |||
| SubGraph, | |||
| RandomView | |||
| ) | |||
| @@ -10,28 +10,28 @@ def get_view_by_name(view, aug_ratio): | |||
| if view is None: | |||
| return lambda x: x | |||
| elif view == "dropN": | |||
| return UniformSample(ratio=aug_ratio) | |||
| return DropNode(aug_ratio=aug_ratio) | |||
| elif view == "permE": | |||
| return EdgePerturbation(ratio=aug_ratio) | |||
| return PermuteEdge(aug_ratio=aug_ratio) | |||
| elif view == "subgraph": | |||
| return RWSample(ratio=aug_ratio) | |||
| return SubGraph(aug_ratio=aug_ratio) | |||
| elif view == "maskN": | |||
| return NodeAttrMask(mask_ratio=aug_ratio) | |||
| return MaskNode(aug_ratio=aug_ratio) | |||
| elif view == "random2": | |||
| canditates = [UniformSample(ratio=aug_ratio), | |||
| RWSample(ratio=aug_ratio)] | |||
| canditates = [DropNode(aug_ratio=aug_ratio), | |||
| SubGraph(aug_ratio=aug_ratio)] | |||
| return RandomView(candidates=canditates) | |||
| elif view == "random3": | |||
| canditates = [UniformSample(ratio=aug_ratio), | |||
| RWSample(ratio=aug_ratio), | |||
| EdgePerturbation(ratio=aug_ratio)] | |||
| canditates = [DropNode(aug_ratio=aug_ratio), | |||
| SubGraph(aug_ratio=aug_ratio), | |||
| PermuteEdge(aug_ratio=aug_ratio)] | |||
| return RandomView(candidates=canditates) | |||
| elif view == "random4": | |||
| canditates = [UniformSample(ratio=aug_ratio), | |||
| RWSample(ratio=aug_ratio), | |||
| EdgePerturbation(ratio=aug_ratio), | |||
| NodeAttrMask(mask_ratio=aug_ratio)] | |||
| canditates = [DropNode(aug_ratio=aug_ratio), | |||
| SubGraph(aug_ratio=aug_ratio), | |||
| PermuteEdge(aug_ratio=aug_ratio), | |||
| MaskNode(aug_ratio=aug_ratio)] | |||
| return RandomView(candidates=canditates) | |||
| else: | |||
| raise NotImplementedError(f'The augmentation method must be in ["dropN", "permE", "subgraph", \ | |||
| "maskN", "random2", "random3", "random4"] or None. And {view} is not supported yet.') | |||
| raise NotImplementedError(f'{view} is not supported yet. Support: ["dropN", "permE", "subgraph", \ | |||
| "maskN", "random2", "random3", "random4", None]') | |||
| @@ -0,0 +1,137 @@ | |||
| # pyg augmentation method from <https://github.com/Shen-Lab/GraphCL/> | |||
| import random | |||
| import torch | |||
| import numpy as np | |||
| from itertools import repeat, product | |||
| from torch_geometric.data import Batch | |||
class BaseAugmentation:
    """Base class for PyG graph augmentations applied batch-wise.

    Subclasses override :meth:`_aug_data` to transform a single
    ``torch_geometric.data.Data`` instance; :meth:`__call__` maps it over
    every graph in a ``Batch`` and re-collates the results.
    """

    def __init__(self, aug_ratio=None):
        # Fraction of nodes/edges affected; interpretation is subclass-specific.
        self.aug_ratio = aug_ratio

    def _aug_data(self, data):
        """Return an augmented copy of a single graph. Must be overridden."""
        # Previously this silently returned None via `pass`, producing a
        # confusing downstream failure in Batch.from_data_list; fail loudly.
        raise NotImplementedError(
            "Subclasses of BaseAugmentation must implement _aug_data"
        )

    def __call__(self, batch):
        # Augment each graph independently, then rebuild the batch.
        new_data = [self._aug_data(data) for data in batch.to_data_list()]
        return Batch.from_data_list(new_data)
class DropNode(BaseAugmentation):
    """Augmentation that removes a random fraction of nodes from a graph.

    ``int(node_num * aug_ratio)`` nodes are dropped; edges are restricted
    to the surviving nodes and re-indexed into the new node numbering.
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        node_num, _ = data.x.size()
        drop_num = int(node_num * self.aug_ratio)

        # Randomly choose the surviving nodes, kept in ascending order so the
        # adjacency slicing below also re-indexes edges consistently.
        idx_perm = np.random.permutation(node_num)
        idx_nondrop = idx_perm[drop_num:]
        idx_nondrop.sort()

        # Rebuild edges via a dense adjacency restricted to surviving nodes.
        # NOTE(review): O(node_num^2) memory — fine for small molecule-scale
        # graphs, may not scale to very large graphs.
        edge_index = data.edge_index.numpy()
        adj = torch.zeros((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 1
        adj = adj[idx_nondrop, :][:, idx_nondrop]
        edge_index = adj.nonzero().t()

        try:
            data.edge_index = edge_index
            data.x = data.x[idx_nondrop]
        except Exception:
            # Best-effort: if the Data object rejects the update (e.g. shape
            # validation), return the graph unmodified rather than fail.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            pass
        return data
class PermuteEdge(BaseAugmentation):
    """Augmentation that rewires a random fraction of edges.

    ``int(edge_num * aug_ratio)`` edges are removed and replaced by the
    same number of uniformly sampled endpoint pairs (duplicates and
    self-loops are not filtered, matching the reference GraphCL code).
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        permute_num = int(edge_num * self.aug_ratio)

        edge_index = data.edge_index.numpy()
        # New random endpoints for the replaced edges.
        idx_add = np.random.choice(node_num, (2, permute_num))
        # Keep a random subset of the original edges and append the new ones.
        keep = np.random.choice(edge_num, (edge_num - permute_num), replace=False)
        edge_index = np.concatenate((edge_index[:, keep], idx_add), axis=1)
        data.edge_index = torch.tensor(edge_index)
        return data
class SubGraph(BaseAugmentation):
    """Augmentation that keeps a randomly grown subgraph.

    Starting from a random seed node, frontier neighbors are sampled until
    roughly ``aug_ratio`` of the nodes are collected; the graph is then
    restricted to those nodes (with self-loops added) and re-indexed.
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        node_num, _ = data.x.size()
        sub_num = int(node_num * self.aug_ratio)
        edge_index = data.edge_index.numpy()

        # Grow the node set from a random seed by sampling from the frontier.
        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set(edge_index[1][edge_index[0] == idx_sub[0]])
        count = 0
        while len(idx_sub) <= sub_num:
            count = count + 1
            if count > node_num:
                break
            if len(idx_neigh) == 0:
                break
            sample_node = np.random.choice(list(idx_neigh))
            if sample_node in idx_sub:
                continue
            idx_sub.append(sample_node)
            # BUGFIX: set.union returns a new set and does not mutate in
            # place; the original discarded the result, so the frontier never
            # grew beyond the seed node's direct neighbors.
            idx_neigh = idx_neigh.union(
                set(edge_index[1][edge_index[0] == idx_sub[-1]])
            )

        idx_nondrop = idx_sub
        data.x = data.x[idx_nondrop]

        # Restrict the adjacency (plus self-loops) to the sampled nodes; the
        # slicing re-indexes edges into the new node numbering.
        # NOTE(review): dense O(node_num^2) adjacency — acceptable for small
        # graphs only.
        adj = torch.zeros((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 1
        adj[list(range(node_num)), list(range(node_num))] = 1
        adj = adj[idx_nondrop, :][:, idx_nondrop]
        data.edge_index = adj.nonzero().t()
        return data
class MaskNode(BaseAugmentation):
    """Augmentation that masks the features of a random fraction of nodes.

    ``int(node_num * aug_ratio)`` nodes have their feature vectors replaced
    by the mean feature vector of the graph.
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        node_num, _ = data.x.size()
        mask_num = int(node_num * self.aug_ratio)
        # Mask token: per-dimension mean over all nodes of this graph.
        token = data.x.mean(dim=0)
        idx_mask = np.random.choice(node_num, mask_num, replace=False)
        # BUGFIX: `torch.tensor(token, ...)` on an existing tensor emits a
        # UserWarning; clone/detach and cast explicitly instead.
        data.x[idx_mask] = token.clone().detach().to(torch.float32)
        return data
class RandomView(BaseAugmentation):
    """Augmentation that applies one randomly chosen candidate per graph."""

    def __init__(self, candidates):
        super().__init__()
        # Pool of BaseAugmentation instances to draw from.
        self.candidates = candidates

    def _aug_data(self, data):
        chosen = random.choice(self.candidates)
        return chosen._aug_data(data)
| @@ -302,7 +302,7 @@ class SSLGraphClassifier(BaseClassifier): | |||
| num_classes=num_classes, | |||
| feval=evaluator_list, | |||
| device=self.runtime_device, | |||
| loss="NCE" if not hasattr(dataset, "loss") else dataset.loss, | |||
| loss="NT_Xent" if not hasattr(dataset, "loss") else dataset.loss, | |||
| num_graph_features=(0 | |||
| if not hasattr(dataset[0], "gf") | |||
| else dataset[0].gf.size(1)) if BACKEND == 'pyg' else | |||
| @@ -35,12 +35,12 @@ def test_graph_trainer(): | |||
| prediction_head="sumpoolmlp", | |||
| views_fn=["random2", "random2"], | |||
| batch_size=128, | |||
| p_lr=5.6004725115062315e-05, | |||
| p_weight_decay=0.00022810837622188083, | |||
| p_epoch=267, | |||
| f_epoch=131, | |||
| f_lr=0.0005362155524564354, | |||
| f_weight_decay=0.0022069814932058804, | |||
| p_lr=0.0001, | |||
| p_weight_decay=0.0002, | |||
| p_epoch=300, | |||
| f_epoch=150, | |||
| f_lr=0.0001, | |||
| f_weight_decay=0.002, | |||
| p_early_stopping_round=50, | |||
| f_early_stopping_round=50, | |||
| z_dim=128, | |||
| @@ -0,0 +1,115 @@ | |||
| import os | |||
| import random | |||
| import torch | |||
| import torch.nn as nn | |||
| import numpy as np | |||
| from autogl.module.train.ssl import GraphCLSemisupervisedTrainer | |||
| from autogl.datasets import build_dataset_from_name, utils | |||
| from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset | |||
| from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer | |||
| from autogl.module.model.decoders import BaseDecoderMaintainer | |||
| from autogl.solver.utils import set_seed | |||
def fixed(**kwargs):
    """Wrap keyword arguments as a list of FIXED hyper-parameter specs."""
    return [
        {"parameterName": name, "type": "FIXED", "value": value}
        for name, value in kwargs.items()
    ]
if __name__ == "__main__":
    import argparse

    # ---- command line -----------------------------------------------------
    parser = argparse.ArgumentParser('ssl pyg trainer')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--dataset', type=str, choices=['MUTAG', 'NCI1', 'PROTEINS', 'PTC_MR'], default='PROTEINS')
    parser.add_argument('--dataset_seed', type=int, default=2021)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--repeat', type=int, default=50)
    parser.add_argument('--encoder', type=str, choices=['gin', 'gcn'], default='gcn')
    parser.add_argument('--p_lr', type=float, default=0.0001)
    parser.add_argument('--p_weight_decay', type=float, default=0)
    parser.add_argument('--p_epoch', type=int, default=100)
    parser.add_argument('--f_lr', type=float, default=0.001)
    parser.add_argument('--f_weight_decay', type=float, default=0)
    parser.add_argument('--f_epoch', type=int, default=100)
    parser.add_argument('--epoch', type=int, default=100)
    args = parser.parse_args()

    # ---- dataset: convert to PyG and make a 10%/10% train/val split -------
    dataset = build_dataset_from_name(args.dataset)
    dataset = convert_dataset(dataset)
    utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022)

    # Accumulated train / val / test results over the repeated runs.
    accs = [[], [], []]

    # Fixed hyper-parameters for encoder / decoder / prediction head.
    encoder_hp = {
        "num_layers": 5,
        "hidden": [32, 64, 64, 64],
        "dropout": 0.5,
        "act": "elu",
        "eps": "true"
    }
    decoder_hp = {
        "hidden": 32,
        "act": "tanh",
        "dropout": 0.35
    }
    prediction_head = {
        "hidden": 128,
        "act": "relu",
        "dropout": 0.4
    }

    from tqdm import tqdm
    for seed in tqdm(range(args.repeat)):
        set_seed(seed)
        # Build an uninitialized trainer, then re-instantiate it from the
        # explicit hyper-parameter dictionary below.
        trainer = GraphCLSemisupervisedTrainer(
            model=(args.encoder, 'sumpoolmlp'),
            prediction_head='sumpoolmlp',
            views_fn=['random2', 'random2'],
            device=args.device,
            num_features=dataset[0].x.size(1),
            num_classes=max(data.y.item() for data in dataset) + 1,
            batch_size=args.batch_size,
            z_dim=128,
            init=False
        )
        trainer.initialize()
        trainer = trainer.duplicate_from_hyper_parameter({
            'trainer': {
                'batch_size': args.batch_size,
                'p_lr': args.p_lr,
                'p_weight_decay': args.p_weight_decay,
                'p_epoch': args.p_epoch,
                'p_early_stopping_round': args.p_epoch + 1,
                'f_lr': args.f_lr,
                'f_weight_decay': args.f_weight_decay,
                'f_epoch': args.f_epoch,
                'f_early_stopping_round': args.f_epoch + 1,
            },
            "encoder": encoder_hp,
            "decoder": decoder_hp,
            "prediction_head": prediction_head
        })

        trainer.train(dataset, False)
        out = trainer.predict(dataset, 'test').detach().cpu().numpy()
        train_result = trainer.evaluate(dataset, 'train')
        valid_result = trainer.evaluate(dataset, 'val')
        test_result = trainer.evaluate(dataset, 'test')
        print(f"{train_result[0]} - {valid_result[0]} - {test_result[0]}")
        accs[0].append(train_result[0])
        accs[1].append(valid_result[0])
        accs[2].append(test_result[0])

    # Report mean ~ std for train, val and test accuracy.
    for split_accs in accs:
        print('{:.4f} ~ {:.4f}'.format(np.mean(split_accs), np.std(split_accs)))