From 9ee138fbcf1cc67852039cf6bc9cedb2dd62e874 Mon Sep 17 00:00:00 2001 From: zhanke Date: Thu, 19 Nov 2020 09:08:19 +0800 Subject: [PATCH] change gnn seed --- model_zoo/official/gnn/gat/train.py | 2 -- model_zoo/official/gnn/gcn/train.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/model_zoo/official/gnn/gat/train.py b/model_zoo/official/gnn/gat/train.py index 1e022cbd81..20fcb20ef7 100644 --- a/model_zoo/official/gnn/gat/train.py +++ b/model_zoo/official/gnn/gat/train.py @@ -19,7 +19,6 @@ import os import numpy as np import mindspore.context as context from mindspore.train.serialization import save_checkpoint, load_checkpoint -from mindspore.common import set_seed from mindspore import Tensor from src.config import GatConfig @@ -27,7 +26,6 @@ from src.dataset import load_and_process from src.gat import GAT from src.utils import LossAccuracyWrapper, TrainGAT -set_seed(0) def train(): """Train GAT model.""" diff --git a/model_zoo/official/gnn/gcn/train.py b/model_zoo/official/gnn/gcn/train.py index 2f01080581..706d4ee67e 100644 --- a/model_zoo/official/gnn/gcn/train.py +++ b/model_zoo/official/gnn/gcn/train.py @@ -27,7 +27,6 @@ from matplotlib import animation from sklearn import manifold from mindspore import context from mindspore import Tensor -from mindspore.common import set_seed from mindspore.train.serialization import save_checkpoint, load_checkpoint from src.gcn import GCN @@ -51,7 +50,6 @@ def train(): """Train model.""" parser = argparse.ArgumentParser(description='GCN') parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory') - parser.add_argument('--seed', type=int, default=0, help='Random seed') parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training') parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation') parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') @@ -60,7 +58,6 @@ def train(): if not os.path.exists("ckpts"): os.mkdir("ckpts") - set_seed(args_opt.seed) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) config = ConfigGCN()