| @@ -15,13 +15,9 @@ | |||||
| """GCN.""" | """GCN.""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import nn | from mindspore import nn | ||||
| from mindspore.common.parameter import ParameterTuple | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.nn.layer.activation import get_activation | from mindspore.nn.layer.activation import get_activation | ||||
| from model_zoo.gcn.src.metrics import Loss, Accuracy | |||||
| def glorot(shape): | def glorot(shape): | ||||
| @@ -105,116 +101,3 @@ class GCN(nn.Cell): | |||||
| output0 = self.layer0(self.adj, self.feature) | output0 = self.layer0(self.adj, self.feature) | ||||
| output1 = self.layer1(self.adj, output0) | output1 = self.layer1(self.adj, output0) | ||||
| return output1 | return output1 | ||||
| class LossAccuracyWrapper(nn.Cell): | |||||
| """ | |||||
| Wraps the GCN model with loss and accuracy cell. | |||||
| Args: | |||||
| network (Cell): GCN network. | |||||
| label (numpy.ndarray): Dataset labels. | |||||
| mask (numpy.ndarray): Mask for training, evaluation or test. | |||||
| weight_decay (float): Weight decay parameter for weight of the first convolution layer. | |||||
| """ | |||||
| def __init__(self, network, label, mask, weight_decay): | |||||
| super(LossAccuracyWrapper, self).__init__() | |||||
| self.network = network | |||||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | |||||
| self.accuracy = Accuracy(label, mask) | |||||
| def construct(self): | |||||
| preds = self.network() | |||||
| loss = self.loss(preds) | |||||
| accuracy = self.accuracy(preds) | |||||
| return loss, accuracy | |||||
| class LossWrapper(nn.Cell): | |||||
| """ | |||||
| Wraps the GCN model with loss. | |||||
| Args: | |||||
| network (Cell): GCN network. | |||||
| label (numpy.ndarray): Dataset labels. | |||||
| mask (numpy.ndarray): Mask for training. | |||||
| weight_decay (float): Weight decay parameter for weight of the first convolution layer. | |||||
| """ | |||||
| def __init__(self, network, label, mask, weight_decay): | |||||
| super(LossWrapper, self).__init__() | |||||
| self.network = network | |||||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | |||||
| def construct(self): | |||||
| preds = self.network() | |||||
| loss = self.loss(preds) | |||||
| return loss | |||||
| class TrainOneStepCell(nn.Cell): | |||||
| r""" | |||||
| Network training package class. | |||||
| Wraps the network with an optimizer. The resulting Cell be trained without inputs. | |||||
| Backward graph will be created in the construct function to do parameter updating. Different | |||||
| parallel modes are available to run the training. | |||||
| Args: | |||||
| network (Cell): The training network. | |||||
| optimizer (Cell): Optimizer for updating the weights. | |||||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||||
| Outputs: | |||||
| Tensor, a scalar Tensor with shape :math:`()`. | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() | |||||
| >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> loss_net = nn.WithLossCell(net, loss_fn) | |||||
| >>> train_net = nn.TrainOneStepCell(loss_net, optim) | |||||
| """ | |||||
| def __init__(self, network, optimizer, sens=1.0): | |||||
| super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.add_flags(defer_inline=True) | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||||
| self.sens = sens | |||||
| def construct(self): | |||||
| weights = self.weights | |||||
| loss = self.network() | |||||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| grads = self.grad(self.network, weights)(sens) | |||||
| return F.depend(loss, self.optimizer(grads)) | |||||
| class TrainNetWrapper(nn.Cell): | |||||
| """ | |||||
| Wraps the GCN model with optimizer. | |||||
| Args: | |||||
| network (Cell): GCN network. | |||||
| label (numpy.ndarray): Dataset labels. | |||||
| mask (numpy.ndarray): Mask for training, evaluation or test. | |||||
| config (ConfigGCN): Configuration for GCN. | |||||
| """ | |||||
| def __init__(self, network, label, mask, config): | |||||
| super(TrainNetWrapper, self).__init__(auto_prefix=True) | |||||
| self.network = network | |||||
| loss_net = LossWrapper(network, label, mask, config.weight_decay) | |||||
| optimizer = nn.Adam(loss_net.trainable_params(), | |||||
| learning_rate=config.learning_rate) | |||||
| self.loss_train_net = TrainOneStepCell(loss_net, optimizer) | |||||
| self.accuracy = Accuracy(label, mask) | |||||
| def construct(self): | |||||
| loss = self.loss_train_net() | |||||
| accuracy = self.accuracy(self.network()) | |||||
| return loss, accuracy | |||||
| @@ -17,6 +17,9 @@ from mindspore import nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.parameter import ParameterTuple | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| class Loss(nn.Cell): | class Loss(nn.Cell): | ||||
| @@ -68,3 +71,116 @@ class Accuracy(nn.Cell): | |||||
| mask = mask / mask_reduce | mask = mask / mask_reduce | ||||
| accuracy_all *= mask | accuracy_all *= mask | ||||
| return self.mean(accuracy_all) | return self.mean(accuracy_all) | ||||
| class LossAccuracyWrapper(nn.Cell): | |||||
| """ | |||||
| Wraps the GCN model with loss and accuracy cell. | |||||
| Args: | |||||
| network (Cell): GCN network. | |||||
| label (numpy.ndarray): Dataset labels. | |||||
| mask (numpy.ndarray): Mask for training, evaluation or test. | |||||
| weight_decay (float): Weight decay parameter for weight of the first convolution layer. | |||||
| """ | |||||
| def __init__(self, network, label, mask, weight_decay): | |||||
| super(LossAccuracyWrapper, self).__init__() | |||||
| self.network = network | |||||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | |||||
| self.accuracy = Accuracy(label, mask) | |||||
| def construct(self): | |||||
| preds = self.network() | |||||
| loss = self.loss(preds) | |||||
| accuracy = self.accuracy(preds) | |||||
| return loss, accuracy | |||||
| class LossWrapper(nn.Cell): | |||||
| """ | |||||
| Wraps the GCN model with loss. | |||||
| Args: | |||||
| network (Cell): GCN network. | |||||
| label (numpy.ndarray): Dataset labels. | |||||
| mask (numpy.ndarray): Mask for training. | |||||
| weight_decay (float): Weight decay parameter for weight of the first convolution layer. | |||||
| """ | |||||
| def __init__(self, network, label, mask, weight_decay): | |||||
| super(LossWrapper, self).__init__() | |||||
| self.network = network | |||||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | |||||
| def construct(self): | |||||
| preds = self.network() | |||||
| loss = self.loss(preds) | |||||
| return loss | |||||
| class TrainOneStepCell(nn.Cell): | |||||
| r""" | |||||
| Network training package class. | |||||
| Wraps the network with an optimizer. The resulting Cell be trained without inputs. | |||||
| Backward graph will be created in the construct function to do parameter updating. Different | |||||
| parallel modes are available to run the training. | |||||
| Args: | |||||
| network (Cell): The training network. | |||||
| optimizer (Cell): Optimizer for updating the weights. | |||||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||||
| Outputs: | |||||
| Tensor, a scalar Tensor with shape :math:`()`. | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() | |||||
| >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> loss_net = nn.WithLossCell(net, loss_fn) | |||||
| >>> train_net = nn.TrainOneStepCell(loss_net, optim) | |||||
| """ | |||||
| def __init__(self, network, optimizer, sens=1.0): | |||||
| super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.add_flags(defer_inline=True) | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||||
| self.sens = sens | |||||
| def construct(self): | |||||
| weights = self.weights | |||||
| loss = self.network() | |||||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| grads = self.grad(self.network, weights)(sens) | |||||
| return F.depend(loss, self.optimizer(grads)) | |||||
| class TrainNetWrapper(nn.Cell): | |||||
| """ | |||||
| Wraps the GCN model with optimizer. | |||||
| Args: | |||||
| network (Cell): GCN network. | |||||
| label (numpy.ndarray): Dataset labels. | |||||
| mask (numpy.ndarray): Mask for training, evaluation or test. | |||||
| config (ConfigGCN): Configuration for GCN. | |||||
| """ | |||||
| def __init__(self, network, label, mask, config): | |||||
| super(TrainNetWrapper, self).__init__(auto_prefix=True) | |||||
| self.network = network | |||||
| loss_net = LossWrapper(network, label, mask, config.weight_decay) | |||||
| optimizer = nn.Adam(loss_net.trainable_params(), | |||||
| learning_rate=config.learning_rate) | |||||
| self.loss_train_net = TrainOneStepCell(loss_net, optimizer) | |||||
| self.accuracy = Accuracy(label, mask) | |||||
| def construct(self): | |||||
| loss = self.loss_train_net() | |||||
| accuracy = self.accuracy(self.network()) | |||||
| return loss, accuracy | |||||
| @@ -26,9 +26,10 @@ from matplotlib import animation | |||||
| from sklearn import manifold | from sklearn import manifold | ||||
| from mindspore import context | from mindspore import context | ||||
| from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper | |||||
| from model_zoo.gcn.src.config import ConfigGCN | |||||
| from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask | |||||
| from src.gcn import GCN | |||||
| from src.metrics import LossAccuracyWrapper, TrainNetWrapper | |||||
| from src.config import ConfigGCN | |||||
| from src.dataset import get_adj_features_labels, get_mask | |||||
| def t_SNE(out_feature, dim): | def t_SNE(out_feature, dim): | ||||
| @@ -17,7 +17,8 @@ import time | |||||
| import pytest | import pytest | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context | from mindspore import context | ||||
| from model_zoo.gcn.src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper | |||||
| from model_zoo.gcn.src.gcn import GCN | |||||
| from model_zoo.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper | |||||
| from model_zoo.gcn.src.config import ConfigGCN | from model_zoo.gcn.src.config import ConfigGCN | ||||
| from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask | from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask | ||||