| @@ -0,0 +1,25 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.bgcf import BGCF | |||||
| def bgcf(*args, **kwargs): | |||||
| return BGCF(*args, **kwargs) | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "bgcf": | |||||
| return bgcf(*args, **kwargs) | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -0,0 +1,25 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.gat import GAT | |||||
| def gat(*args, **kwargs): | |||||
| return GAT(*args, **kwargs) | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "gat": | |||||
| return gat(*args, **kwargs) | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -0,0 +1,25 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.gcn import GCN | |||||
| def gcn(*args, **kwargs): | |||||
| return GCN(*args, **kwargs) | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "gcn": | |||||
| return gcn(*args, **kwargs) | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -25,7 +25,7 @@ from mindspore.ops import functional as F | |||||
| class Loss(nn.Cell): | class Loss(nn.Cell): | ||||
| """Softmax cross-entropy loss with masking.""" | """Softmax cross-entropy loss with masking.""" | ||||
| def __init__(self, label, mask, weight_decay, param): | def __init__(self, label, mask, weight_decay, param): | ||||
| super(Loss, self).__init__() | |||||
| super(Loss, self).__init__(auto_prefix=False) | |||||
| self.label = Tensor(label) | self.label = Tensor(label) | ||||
| self.mask = Tensor(mask) | self.mask = Tensor(mask) | ||||
| self.loss = P.SoftmaxCrossEntropyWithLogits() | self.loss = P.SoftmaxCrossEntropyWithLogits() | ||||
| @@ -55,7 +55,7 @@ class Loss(nn.Cell): | |||||
| class Accuracy(nn.Cell): | class Accuracy(nn.Cell): | ||||
| """Accuracy with masking.""" | """Accuracy with masking.""" | ||||
| def __init__(self, label, mask): | def __init__(self, label, mask): | ||||
| super(Accuracy, self).__init__() | |||||
| super(Accuracy, self).__init__(auto_prefix=False) | |||||
| self.label = Tensor(label) | self.label = Tensor(label) | ||||
| self.mask = Tensor(mask) | self.mask = Tensor(mask) | ||||
| self.equal = P.Equal() | self.equal = P.Equal() | ||||
| @@ -86,7 +86,7 @@ class LossAccuracyWrapper(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, network, label, mask, weight_decay): | def __init__(self, network, label, mask, weight_decay): | ||||
| super(LossAccuracyWrapper, self).__init__() | |||||
| super(LossAccuracyWrapper, self).__init__(auto_prefix=False) | |||||
| self.network = network | self.network = network | ||||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | ||||
| self.accuracy = Accuracy(label, mask) | self.accuracy = Accuracy(label, mask) | ||||
| @@ -110,7 +110,7 @@ class LossWrapper(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, network, label, mask, weight_decay): | def __init__(self, network, label, mask, weight_decay): | ||||
| super(LossWrapper, self).__init__() | |||||
| super(LossWrapper, self).__init__(auto_prefix=False) | |||||
| self.network = network | self.network = network | ||||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | ||||
| @@ -174,7 +174,7 @@ class TrainNetWrapper(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, network, label, mask, config): | def __init__(self, network, label, mask, config): | ||||
| super(TrainNetWrapper, self).__init__(auto_prefix=True) | |||||
| super(TrainNetWrapper, self).__init__(auto_prefix=False) | |||||
| self.network = network | self.network = network | ||||
| loss_net = LossWrapper(network, label, mask, config.weight_decay) | loss_net = LossWrapper(network, label, mask, config.weight_decay) | ||||
| optimizer = nn.Adam(loss_net.trainable_params(), | optimizer = nn.Adam(loss_net.trainable_params(), | ||||
| @@ -16,7 +16,7 @@ | |||||
| """ | """ | ||||
| GCN training script. | GCN training script. | ||||
| """ | """ | ||||
| import os | |||||
| import time | import time | ||||
| import argparse | import argparse | ||||
| import ast | import ast | ||||
| @@ -27,6 +27,7 @@ from matplotlib import animation | |||||
| from sklearn import manifold | from sklearn import manifold | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint | |||||
| from src.gcn import GCN | from src.gcn import GCN | ||||
| from src.metrics import LossAccuracyWrapper, TrainNetWrapper | from src.metrics import LossAccuracyWrapper, TrainNetWrapper | ||||
| @@ -55,6 +56,8 @@ def train(): | |||||
| parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') | parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test') | ||||
| parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph') | parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph') | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| if not os.path.exists("ckpts"): | |||||
| os.mkdir("ckpts") | |||||
| set_seed(args_opt.seed) | set_seed(args_opt.seed) | ||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| @@ -72,7 +75,6 @@ def train(): | |||||
| gcn_net.add_flags_recursive(fp16=True) | gcn_net.add_flags_recursive(fp16=True) | ||||
| eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay) | eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay) | ||||
| test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay) | |||||
| train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config) | train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config) | ||||
| loss_list = [] | loss_list = [] | ||||
| @@ -112,7 +114,12 @@ def train(): | |||||
| if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]): | if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]): | ||||
| print("Early stopping...") | print("Early stopping...") | ||||
| break | break | ||||
| save_checkpoint(gcn_net, "ckpts/gcn.ckpt") | |||||
| gcn_net_test = GCN(config, adj, feature, class_num) | |||||
| load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test) | |||||
| gcn_net_test.add_flags_recursive(fp16=True) | |||||
| test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay) | |||||
| t_test = time.time() | t_test = time.time() | ||||
| test_net.set_train(False) | test_net.set_train(False) | ||||
| test_result = test_net() | test_result = test_net() | ||||