|
|
|
@@ -27,7 +27,6 @@ from matplotlib import animation |
|
|
|
from sklearn import manifold |
|
|
|
from mindspore import context |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore.common import set_seed |
|
|
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint |
|
|
|
|
|
|
|
from src.gcn import GCN |
|
|
|
@@ -51,7 +50,6 @@ def train(): |
|
|
|
"""Train model.""" |
|
|
|
parser = argparse.ArgumentParser(description='GCN') |
|
|
|
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory') |
|
|
|
parser.add_argument('--seed', type=int, default=0, help='Random seed') |
|
|
|
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training') |
|
|
|
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation') |
|
|
|
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') |
|
|
|
@@ -60,7 +58,6 @@ def train(): |
|
|
|
if not os.path.exists("ckpts"): |
|
|
|
os.mkdir("ckpts") |
|
|
|
|
|
|
|
set_seed(args_opt.seed) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, |
|
|
|
device_target="Ascend", save_graphs=False) |
|
|
|
config = ConfigGCN() |
|
|
|
|