Merge pull request !5696 from Xiaoda/20-moving-multi-graph-interface-internaltags/v1.0.0
| @@ -17,7 +17,5 @@ This interface is ONLY used in Auto-parallel procedure. | |||
| """ | |||
| from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \ | |||
| set_algo_parameters | |||
| from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs | |||
| __all__ = ["set_multi_subgraphs", "get_multi_subgraphs", | |||
| "get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] | |||
| __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] | |||
| @@ -589,7 +589,7 @@ def reset_cost_model_context(): | |||
| """Reset cost model context attributes.""" | |||
| cost_model_context().reset_cost_model() | |||
| def set_multi_subgraphs(multi_subgraph=True): | |||
| def _set_multi_subgraphs(multi_subgraph=True): | |||
| """ | |||
| Set the flag of ANF graph containing multiple subgraphs. | |||
| @@ -598,7 +598,7 @@ def set_multi_subgraphs(multi_subgraph=True): | |||
| """ | |||
| cost_model_context().set_multi_subgraphs(multi_subgraph) | |||
| def get_multi_subgraphs(): | |||
| def _get_multi_subgraphs(): | |||
| """ | |||
| Get the flag of ANF graph containing multiple subgraphs. | |||
| """ | |||
| @@ -32,6 +32,7 @@ from .. import nn | |||
| from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from ..context import ParallelMode | |||
| from ..parallel._utils import _need_to_full, _to_full_tensor | |||
| from ..parallel._cost_model_context import _set_multi_subgraphs | |||
| from ..common import dtype as mstype | |||
| from .dataset_helper import DatasetHelper | |||
| from . import amp | |||
| @@ -166,6 +167,9 @@ class Model: | |||
| if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
| network.set_auto_parallel() | |||
| if self._optimizer is None: | |||
| # In this case, multiple optimizers are supposed to be included in 'self._network' | |||
| _set_multi_subgraphs() | |||
| return network | |||
| def _build_eval_network(self, metrics, eval_network, eval_indexes): | |||
| @@ -190,6 +194,9 @@ class Model: | |||
| if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
| if self._optimizer: | |||
| self._eval_network = _VirtualDatasetCell(self._eval_network) | |||
| if self._optimizer is None: | |||
| # In this case, multiple optimizers are supposed to be included in 'self._network' | |||
| _set_multi_subgraphs() | |||
| self._eval_network.set_auto_parallel() | |||
| def _build_predict_network(self): | |||
| @@ -197,6 +204,7 @@ class Model: | |||
| self._predict_network = self._network | |||
| if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
| self._predict_network = _VirtualDatasetCell(self._network) | |||
| _set_multi_subgraphs() | |||
| self._predict_network.set_auto_parallel() | |||
| def _clear_metrics(self): | |||
| @@ -22,7 +22,6 @@ from mindspore import Model, context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication.management import get_rank, get_group_size, init | |||
| from mindspore.parallel import set_multi_subgraphs | |||
| from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| @@ -145,7 +144,6 @@ if __name__ == "__main__": | |||
| device_target=wide_deep_config.device_target, save_graphs=True) | |||
| context.set_context(variable_memory_max_size="24GB") | |||
| context.set_context(enable_sparse=True) | |||
| set_multi_subgraphs() | |||
| init() | |||
| if wide_deep_config.host_device_mix == 1: | |||
| context.set_auto_parallel_context( | |||
| @@ -21,7 +21,6 @@ from mindspore import Model, context | |||
| from mindspore.train.callback import TimeMonitor | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication.management import get_rank, get_group_size, init | |||
| from mindspore.parallel import set_multi_subgraphs | |||
| from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| @@ -33,7 +32,6 @@ from src.config import WideDeepConfig | |||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True) | |||
| set_multi_subgraphs() | |||
| init() | |||
| @@ -17,13 +17,13 @@ import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter, ParameterTuple | |||
| from mindspore import context | |||
| from mindspore import context, Model | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn.optim import Adam, FTRL | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.parallel import set_multi_subgraphs | |||
| from mindspore.parallel._cost_model_context import _set_multi_subgraphs | |||
| from mindspore.parallel._utils import _reset_op_id as reset_op_id | |||
| @@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell): | |||
| def test_double_subgraphs(): | |||
| set_multi_subgraphs() | |||
| _set_multi_subgraphs() | |||
| context.set_context(save_graphs=True) | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| @@ -120,3 +120,50 @@ def test_double_subgraphs(): | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]} | |||
| assert strategies == expected_strategies | |||
class DatasetLenet:
    """Minimal in-memory dataset stand-in used to feed ``Model.train`` in tests.

    Iterating yields the same ``predict`` object ``length`` times and then
    raises ``StopIteration``. The cursor is only rewound by ``reset()``.

    Args:
        predict: Object returned on every iteration step.
        label: Stored on the instance; it is not returned by iteration.
        length (int): Number of items produced before iteration stops.
    """

    def __init__(self, predict, label, length=3):
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        # The object acts as its own iterator; the cursor is not rewound here.
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict

    def reset(self):
        """Rewind the cursor so the dataset can be iterated again."""
        self.index = 0

    def get_dataset_size(self):
        """Return the reported dataset size (a fixed 32, independent of ``length``)."""
        return 32

    def get_repeat_count(self):
        """Return the repeat count, which is always 1."""
        return 1

    def create_tuple_iterator(self):
        """Return the dataset itself, which already implements the iterator protocol."""
        return self
def test_double_subgraphs_train():
    """Train through ``Model`` in auto-parallel mode and check the chosen strategies.

    The network is wrapped in ``Model`` without passing an optimizer, and this
    test does not call ``_set_multi_subgraphs()`` itself.
    NOTE(review): presumably ``Model`` is expected to set the multi-subgraph
    flag in this situation -- confirm against the ``Model`` build path.
    """
    expected = {
        'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
        'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]],
    }

    # Single-device auto-parallel configuration.
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(device_num=1, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")

    network = TrainStepWarp(NetWithLoss(Net()))

    # The dataset yields one int32 tensor of ones per step; the label is None.
    ones = np.ones([8, 8, 8, 8]).astype(np.int32)
    dataset = DatasetLenet(Tensor(ones), None)

    # One epoch, without dataset sinking.
    Model(network).train(1, dataset, dataset_sink_mode=False)

    actual = _executor._get_strategy(network)
    assert actual == expected