@@ -738,27 +738,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
        inputs.emplace_back(input_node);
      }
    }
    auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    if (abs == nullptr || is_const_prim || is_const_input) {
      MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
      ValuePtr input_value = PyAttrValue(obj);
      abs = input_value->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        auto config = abstract::AbstractBase::kBroadenTensorOnly;
        abs = abs->Broaden(config);
        MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
      }
      node_abs_map_[id] = abs;
    }
    (*args_spec_list).emplace_back(abs);
    (*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
  }
  CNodePtr cnode = nullptr;

@@ -770,6 +750,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
  return cnode;
}
abstract::AbstractBasePtr PynativeExecutor::CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                                            const abstract::AbstractBasePtr &abs, const std::string &id,
                                                            size_t index) {
  MS_EXCEPTION_IF_NULL(prim);
  auto const_input_index = prim->get_const_input_indexes();
  bool have_const_input = !const_input_index.empty();
  bool is_const_prim = prim->is_const_prim();
  auto new_abs = abs;
  MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                << prim->is_const_prim();
  bool is_const_input =
    have_const_input && std::find(const_input_index.begin(), const_input_index.end(), index) != const_input_index.end();
  if (abs == nullptr || is_const_prim || is_const_input) {
    MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
    ValuePtr input_value = PyAttrValue(obj);
    MS_EXCEPTION_IF_NULL(input_value);
    new_abs = input_value->ToAbstract();
    if (!is_const_prim && !is_const_input) {
      auto config = abstract::AbstractBase::kBroadenTensorOnly;
      MS_EXCEPTION_IF_NULL(new_abs);
      new_abs = new_abs->Broaden(config);
      MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
    }
    node_abs_map_[id] = new_abs;
  }
  return new_abs;
}

void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                           const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
  MS_EXCEPTION_IF_NULL(is_find);

@@ -1004,6 +1012,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
      return free_param;
    }
    node = graph_info->node_map.at(obj_id).first;
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
    return node;
  }

@@ -2008,9 +2017,14 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
  top_cell_id_ = cell_id;
  in_grad_process_ = true;
  // update forward already run flag with previous top cell
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id = input_args_id + GetId(args[i]) + "_";
  }
  auto pre_top_cell = GetTopCell(cell_id);
  if (pre_top_cell != nullptr) {
    pre_top_cell->forward_already_run = true;
    pre_top_cell->input_args_id = input_args_id;
  }
  auto df_builder = std::make_shared<FuncGraph>();
  auto graph_info = std::make_shared<GraphInfo>(cell_id);

@@ -2019,6 +2033,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
  resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
  auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
  top_cell_info->forward_already_run = true;
  top_cell_info->input_args_id = input_args_id;
  if (!IsTopestGraph(cell_id)) {
    top_cell_info->top_cell_index = cell_graph_list_.size();
    top_cell_index_ = top_cell_info->top_cell_index;

@@ -2862,11 +2877,24 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
}

py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
  bool forward_run = false;
  const auto &cell_id = GetCellId(cell, args);
  // Checkout whether top cell has already run.
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id = input_args_id + GetId(args[i]) + "_";
  }
  auto top_cell = GetTopCell(cell_id);
  bool forward_run = false;
  if (top_cell != nullptr) {
    forward_run = top_cell->forward_already_run;
    if (!top_cell->input_args_id.empty() && top_cell->input_args_id != input_args_id && top_cell->forward_already_run &&
        CheckDynamicCell(cell_id)) {
      MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
                         "forward process will run again";
      top_cell->forward_already_run = false;
      top_cell->input_args_id = input_args_id;
    } else {
      forward_run = top_cell->forward_already_run;
    }
    if (forward_run) {
      top_cell_index_ = top_cell->top_cell_index;
    }
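
The CheckAlreadyRun change above keys a previously run top cell on the concatenated ids of its input arguments: when the cell's construct is dynamic and that key changes, the cached forward run is discarded so the forward pass executes again. A rough Python rendering of that bookkeeping (illustrative only; top_cell, get_id and is_dynamic_cell are stand-ins, not the real MindSpore API):

def build_input_args_id(args, get_id):
    # Mirror of the C++ loop above: concatenate per-argument ids with "_".
    return "".join(get_id(arg) + "_" for arg in args)


def forward_already_run(top_cell, args, get_id, is_dynamic_cell):
    input_args_id = build_input_args_id(args, get_id)
    if (top_cell.input_args_id and top_cell.input_args_id != input_args_id
            and top_cell.forward_already_run and is_dynamic_cell):
        # Dynamic cell whose inputs changed: invalidate the cached forward result.
        top_cell.forward_already_run = False
        top_cell.input_args_id = input_args_id
    return top_cell.forward_already_run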

@@ -107,6 +107,7 @@ class TopCellInfo {
  std::string cell_id;
  std::string sens_id;
  std::string weights_id;
  std::string input_args_id;
};

using GraphInfoPtr = std::shared_ptr<GraphInfo>;

@@ -209,6 +210,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                       abstract::AbstractBasePtrList *args_spec_list);
  abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                            const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
  void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                           bool *is_find);
  void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);

@@ -307,6 +307,23 @@ class Cell(Cell_):
                res.append(cast(item, dst_type))
        return tuple(res)

    def do_parameter_broadcast(self):
        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
            if not self.parameter_broadcast_done:
                _pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
                self.parameter_broadcast_done = True

    def run_construct(self, cast_inputs, kwargs):
        if self.enable_hook:
            _pynative_exec.enter_construct(self)
            output = self._hook_construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        else:
            _pynative_exec.enter_construct(self)
            output = self.construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        return output

    def __call__(self, *inputs, **kwargs):
        if self.__class__.construct is Cell.construct:
            logger.warning(f"The '{self.__class__}' does not override the method 'construct', "

@@ -324,11 +341,7 @@ class Cell(Cell_):
            out = self.compile_and_run(*inputs)
            return out
        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
            if not self.parameter_broadcast_done:
                _pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
                self.parameter_broadcast_done = True
        self.do_parameter_broadcast()
        for item in inputs:
            if isinstance(item, numpy.ndarray):
                raise TypeError("cell inputs should not be numpy array.")

@@ -349,14 +362,7 @@ class Cell(Cell_):
            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
        if not cast_inputs:
            cast_inputs = inputs
        if self.enable_hook:
            _pynative_exec.enter_construct(self)
            output = self._hook_construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        else:
            _pynative_exec.enter_construct(self)
            output = self.construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        output = self.run_construct(cast_inputs, kwargs)
        if isinstance(output, Parameter):
            output = output.data
        if self.requires_grad is True:

@@ -17,7 +17,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer

@@ -25,6 +25,7 @@ from ..layer.timedistributed import TimeDistributed
__all__ = [
    "TimeDistributed",
    "ForwardValueAndGrad",
    "TrainOneStepCell",
    "WithLossCell",
    "WithGradCell",

@@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from types import FunctionType, MethodType
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from ...common.tensor import Tensor
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C

@@ -174,6 +177,107 @@ class WithGradCell(Cell):
        return grads


class ForwardValueAndGrad(Cell):
    r"""
    Network training package class.

    Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the gradient function to calculate the gradient.

    Args:
        network (Cell): The training network. The network only supports single output.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
            Default: None.
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to the first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            Default: False.
            If sens_param is True, the sensitivity (gradient with respect to output) needs to be passed as a
            positional argument or a keyword argument; when passed as a keyword argument, the key must be 'sens'.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
        - **sens** (Number) - The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Outputs:
        - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
        - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
        >>> labels = Tensor(np.ones([32]).astype(np.int32))
        >>> net = Net()
        >>> weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> # 1) Using the provided WithLossCell
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...     def __init__(self, backbone, loss_fn):
        ...         super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...         self._backbone = backbone
        ...         self._loss_fn = loss_fn
        ...
        ...     def construct(self, x, y, label):
        ...         out = self._backbone(x, y)
        ...         return self._loss_fn(out, label)
        ...
        ...     @property
        ...     def backbone_network(self):
        ...         return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"The type of training network should be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if get_by_list and not isinstance(weights, ParameterTuple):
            raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                            f"ParameterTuple type, but got '{type(weights)}'")
        if get_by_list is not True and weights is not None:
            raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
                            f"NoneType, but got '{type(weights)}'")
        self.network = network
        self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

    def construct(self, *inputs):
        weights = self.weights
        if self.sens_param:
            sens = inputs[-1]
            inputs = inputs[:-1]
        else:
            sens = None
        loss = self.network(*inputs)
        if self.sens_param:
            if not isinstance(sens, Tensor):
                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens)
            grads = self.grad(self.network, weights)(*inputs, sens)
        else:
            grads = self.grad(self.network, weights)(*inputs)
        return loss, grads


class TrainOneStepCell(Cell):
    r"""
    Network training package class.
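
To place the new cell in context: in PyNative mode it returns the forward value together with the gradients, which can be fed straight into an optimizer. A minimal sketch of that usage, mirroring the tests added below (Net and the dataset iterable are placeholders, and the Momentum hyper-parameters are arbitrary):

import mindspore.nn as nn
from mindspore import ParameterTuple

net = Net()  # placeholder network defined elsewhere
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_net = nn.WithLossCell(net, loss_fn)
weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = nn.Momentum(weights, 0.1, 0.9)
forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)

for data, label in dataset:  # placeholder dataset iterable
    # sens=1.0 seeds the backpropagation; grads are ordered like `weights`.
    loss, grads = forward_value_and_grad(data, label, 1.0)
    optimizer(grads)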

@@ -22,10 +22,10 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ParameterTuple
from mindspore import amp
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn import TrainOneStepCell, WithLossCell, ForwardValueAndGrad
from mindspore.nn.cell import Cell
from mindspore.nn.layer.basic import Flatten
from mindspore.nn.layer.conv import Conv2d

@@ -33,6 +33,7 @@ from mindspore.nn.layer.normalization import BatchNorm2d
from mindspore.nn.layer.pooling import MaxPool2d
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import Add

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -399,3 +400,53 @@ def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
    assert (losses[-1][0].asnumpy() < 1)
    assert not losses[-1][1].asnumpy()
    assert (losses[-1][2].asnumpy() > 1)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
    net = resnet50(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()
    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)
    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
    losses = []
    for i in range(0, epoch):
        data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss, grads = train_network(data, label, 1.0)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    assert (losses[-1].asnumpy() < 0.8)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
    net = resnet50(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()
    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)
    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
    losses = []
    for i in range(0, epoch):
        data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss, grads = train_network(data, label, 1.0)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    assert (losses[-1].asnumpy() < 0.8)

@@ -164,3 +164,40 @@ def test_ascend_pynative_lenet():
        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.004
    assert loss_output.asnumpy() > 0.003


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_lenet_with_new_interface():
    context.set_context(mode=context.PYNATIVE_MODE)
    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))
    net = LeNet()
    criterion = CrossEntropyLoss()
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()
    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)
    forward_value_and_grad = nn.ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True)
    total_time = 0
    for epoch in range(0, epoch_size):
        start_time = time.time()
        loss_output, grads = forward_value_and_grad(inputs, labels)
        optimizer(grads)
        end_time = time.time()
        cost_time = end_time - start_time
        total_time = total_time + cost_time
        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.005
    assert loss_output.asnumpy() > 0.003