|
|
|
@@ -19,6 +19,7 @@ GCN training script. |
|
|
|
|
|
|
|
import time |
|
|
|
import argparse |
|
|
|
import ast |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from matplotlib import pyplot as plt |
|
|
|
@@ -51,7 +52,7 @@ def train(): |
|
|
|
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training') |
|
|
|
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation') |
|
|
|
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') |
|
|
|
parser.add_argument('--save_TSNE', type=bool, default=False, help='Whether to save t-SNE graph') |
|
|
|
parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph') |
|
|
|
args_opt = parser.parse_args() |
|
|
|
|
|
|
|
np.random.seed(args_opt.seed) |
|
|
|
|