| @@ -366,26 +366,26 @@ bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const i | |||
| return true; | |||
| } | |||
| void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, const int64_t &device_num) { | |||
| auto out_axes_tuple = out_axes->cast<ValueNodePtr>(); | |||
| void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy, const int64_t &device_num) { | |||
| auto out_strategy_tuple = out_strategy->cast<ValueNodePtr>(); | |||
| bool need_default_strategy = false; | |||
| size_t out_axes_size = 0; | |||
| if (!IsValueNode<ValueTuple>(out_axes_tuple) || | |||
| !CheckLayout(out_axes_tuple, &need_default_strategy, &out_axes_size)) { | |||
| MS_LOG(EXCEPTION) << "out_axes should be a two-dimension tuple"; | |||
| size_t out_strategy_size = 0; | |||
| if (!IsValueNode<ValueTuple>(out_strategy_tuple) || | |||
| !CheckLayout(out_strategy_tuple, &need_default_strategy, &out_strategy_size)) { | |||
| MS_LOG(EXCEPTION) << "out_strategy should be a two-dimension tuple"; | |||
| } | |||
| std::vector<AnfNodePtr> output_nodes; | |||
| GetOutputNodes(func_graph, &output_nodes); | |||
| if (output_nodes.size() != out_axes_size) { | |||
| if (output_nodes.size() != out_strategy_size) { | |||
| MS_LOG(EXCEPTION) << "Output number: " << output_nodes.size() | |||
| << " is not equal to out_axes number: " << out_axes_size; | |||
| << " is not equal to out_strategy number: " << out_strategy_size; | |||
| } | |||
| std::vector<std::vector<int64_t>> output_strategy; | |||
| if (need_default_strategy) { | |||
| GenerateDefaultStrategy(out_axes_tuple, output_nodes, device_num, &output_strategy); | |||
| GenerateDefaultStrategy(out_strategy_tuple, output_nodes, device_num, &output_strategy); | |||
| } else { | |||
| output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_axes_tuple->value()); | |||
| output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_strategy_tuple->value()); | |||
| } | |||
| MS_LOG(WARNING) << "The output strategy will be overwritten as data-parallel"; | |||
| @@ -394,7 +394,8 @@ void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, | |||
| auto output_shape = common::AnfAlgo::GetOutputInferShape(node, 0); | |||
| if (output_shape.size() != output_strategy[i].size()) { | |||
| MS_LOG(EXCEPTION) << "Output dimension: " << output_shape.size() | |||
| << " is not equal to out_axes dimension: " << output_strategy[i].size() << " at index " << i; | |||
| << " is not equal to out_strategy dimension: " << output_strategy[i].size() << " at index " | |||
| << i; | |||
| } | |||
| std::vector<ValuePtr> elements; | |||
| elements.push_back(MakeValue(output_strategy[i])); | |||
| @@ -430,24 +431,25 @@ std::vector<ValuePtr> GetStrategyElements(const CNodePtr &cnode, const std::vect | |||
| return elements; | |||
| } | |||
| void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, const int64_t &device_num) { | |||
| auto in_axes_tuple = in_axes->cast<ValueNodePtr>(); | |||
| void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy, const int64_t &device_num) { | |||
| auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>(); | |||
| bool need_default_strategy = false; | |||
| size_t in_axes_size = 0; | |||
| if (!IsValueNode<ValueTuple>(in_axes_tuple) || !CheckLayout(in_axes_tuple, &need_default_strategy, &in_axes_size)) { | |||
| MS_LOG(EXCEPTION) << "in_axes should be a two-dimension tuple"; | |||
| size_t in_strategy_size = 0; | |||
| if (!IsValueNode<ValueTuple>(in_strategy_tuple) || | |||
| !CheckLayout(in_strategy_tuple, &need_default_strategy, &in_strategy_size)) { | |||
| MS_LOG(EXCEPTION) << "in_strategy should be a two-dimension tuple"; | |||
| } | |||
| std::vector<AnfNodePtr> input_nodes; | |||
| GetInputNodes(func_graph, &input_nodes); | |||
| if (input_nodes.size() != in_axes_size) { | |||
| if (input_nodes.size() != in_strategy_size) { | |||
| MS_LOG(EXCEPTION) << "Input numbers: " << input_nodes.size() | |||
| << " is not equal to in_axes numbers: " << in_axes_size; | |||
| << " is not equal to in_strategy numbers: " << in_strategy_size; | |||
| } | |||
| std::vector<std::vector<int64_t>> input_strategy; | |||
| if (need_default_strategy) { | |||
| GenerateDefaultStrategy(in_axes_tuple, input_nodes, device_num, &input_strategy); | |||
| GenerateDefaultStrategy(in_strategy_tuple, input_nodes, device_num, &input_strategy); | |||
| } else { | |||
| input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_axes_tuple->value()); | |||
| input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_strategy_tuple->value()); | |||
| } | |||
| if (!CheckDeviceNum(input_strategy, device_num)) { | |||
| MS_LOG(EXCEPTION) << "check device number failed"; | |||
| @@ -463,7 +465,7 @@ void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, c | |||
| auto output_shape = common::AnfAlgo::GetOutputInferShape(parameter, 0); | |||
| if (output_shape.size() != input_strategy[i].size()) { | |||
| MS_LOG(EXCEPTION) << "Input dimension: " << output_shape.size() | |||
| << " is not equal to in_axes dimension: " << input_strategy[i].size() << " at index " << i; | |||
| << " is not equal to in_strategy dimension: " << input_strategy[i].size() << " at index " << i; | |||
| } | |||
| AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; | |||
| for (auto ¶m_pair : param_sub_set) { | |||
| @@ -492,13 +494,13 @@ void SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr> | |||
| root->set_flag("auto_parallel", true); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto vnode = cnode->input(1)->cast<ValueNodePtr>(); | |||
| auto in_axes = cnode->input(2); | |||
| auto out_axes = cnode->input(3); | |||
| auto in_strategy = cnode->input(2); | |||
| auto out_strategy = cnode->input(3); | |||
| ScopeGuard scope_guard(vnode->scope()); | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(vnode); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| SetInputLayout(func_graph, in_axes, device_num); | |||
| SetOutputLayout(func_graph, out_axes, device_num); | |||
| SetInputLayout(func_graph, in_strategy, device_num); | |||
| SetOutputLayout(func_graph, out_strategy, device_num); | |||
| } | |||
| } | |||
| } | |||
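Taken together, SetInputLayout/SetOutputLayout enforce a handful of rules on the strategy tuples that arrive from the Python side. The sketch below restates those rules in Python for readability only; the helper name is invented, and the exact device-number rule (assumed here to be "product of splits equals the device number") may differ from what CheckDeviceNum actually implements.

from math import prod

def validate_strategy(strategy, shapes, device_num):
    """Illustrative restatement of the C++ checks above; not the real implementation."""
    if not isinstance(strategy, tuple):
        raise TypeError("strategy should be a two-dimension tuple")
    if len(strategy) != len(shapes):
        raise ValueError(f"entry number: {len(strategy)} is not equal to "
                         f"input/output number: {len(shapes)}")
    for i, (layout, shape) in enumerate(zip(strategy, shapes)):
        if layout is None:                      # None -> default (data-parallel) strategy
            continue
        if not isinstance(layout, tuple):
            raise TypeError("strategy should be a two-dimension tuple")
        if len(layout) != len(shape):
            raise ValueError(f"dimension: {len(shape)} is not equal to "
                             f"strategy dimension: {len(layout)} at index {i}")
        if prod(layout) != device_num:          # assumption about CheckDeviceNum's rule
            raise ValueError("check device number failed")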
| @@ -81,6 +81,19 @@ class Cell(Cell_): | |||
| [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)] | |||
| """ | |||
| class CellGuard: | |||
| """Detecting whether the cell is a top-level cell with the 'with statement'.""" | |||
| def __enter__(self): | |||
| """Enter cell and increase recursion depth count.""" | |||
| _pynative_executor.set_lazy_build(True) | |||
| _pynative_executor.enter_cell() | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| """Exit cell and decrease recursion depth count.""" | |||
| _pynative_executor.exit_cell() | |||
| if _pynative_executor.is_top_cell(): | |||
| _pynative_executor.set_lazy_build(False) | |||
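CellGuard is a small context manager: entering it switches the PyNative executor to lazy build and increases the cell nesting depth, while exiting decreases the depth and drops back to eager build once the top cell is reached. A hedged sketch of the intended call pattern follows; the call site is an assumption and is not part of this diff.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

net = TinyNet()
x = Tensor(np.ones([2, 2]).astype(np.float32))
# Assumed call pattern: the guard wraps the top-level call so that operators are
# queued lazily and flushed once the outermost cell exits.
with nn.Cell.CellGuard():
    out = net.construct(x)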
| IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names', | |||
| '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run', | |||
| '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode', | |||
| @@ -482,11 +495,11 @@ class Cell(Cell_): | |||
| for prim in all_prims: | |||
| prim.add_prim_attr("strategy_gen_mode", "data_parallel") | |||
| def shard(self, in_axes, out_axes, device="Ascend", level=0): | |||
| def shard(self, in_strategy, out_strategy, device="Ascend", level=0): | |||
| """ | |||
| Define the input and output layouts of this cell; the parallel strategies of the remaining ops will be | |||
| generated by sharding propagation. In_axes and out_axes define the input and output layout respectively. | |||
| In_axes/Out_axes should be a tuple each element of which corresponds to the desired layout of | |||
| generated by sharding propagation. in_strategy and out_strategy define the input and output layouts, respectively. | |||
| in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of | |||
| this input/output and None represents data_parallel. | |||
| Note: | |||
| @@ -494,9 +507,9 @@ class Cell(Cell_): | |||
| search_mode in auto_parallel_context set as sharding_propagation. | |||
| Args: | |||
| in_axes (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple | |||
| in_strategy (tuple): Define the layout of inputs; each element of the tuple should be a tuple or None. Tuple | |||
| defines the layout of the corresponding input and None represents a data parallel strategy. | |||
| out_axes (tuple): Define the layout of outputs similar with in_axes. | |||
| out_strategy (tuple): Define the layout of outputs, similar to in_strategy. | |||
| device (string): Select a certain device target. It is not in use right now. | |||
| Support ["CPU", "GPU", "Ascend"]. Default: "Ascend". | |||
| level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation | |||
| @@ -522,30 +535,17 @@ class Cell(Cell_): | |||
| ... def __init__(self): | |||
| ... self.block1 = Block() | |||
| ... self.block2 = Block() | |||
| ... self.block2.shard(in_axes=((2, 1),), out_axes=(None,)) | |||
| ... def construct(self, x): | |||
| ... x = self.block1(x) | |||
| ... x = self.block2(x) | |||
| ... return x | |||
| ... self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,)) | |||
| ... def construct(self, x): | |||
| ... x = self.block1(x) | |||
| ... x = self.block2(x) | |||
| ... return x | |||
| """ | |||
| shard_fn = Shard() | |||
| fn = shard_fn(self, in_axes, out_axes, device, level) | |||
| fn = shard_fn(self, in_strategy, out_strategy, device, level) | |||
| object.__setattr__(self, "_shard_fn", fn) | |||
| return self | |||
| class CellGuard: | |||
| """Detecting whether the cell is a top-level cell with the 'with statement'.""" | |||
| def __enter__(self): | |||
| """Enter cell and increase recursion depth count.""" | |||
| _pynative_executor.set_lazy_build(True) | |||
| _pynative_executor.enter_cell() | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| """Exit cell and decrease recursion depth count.""" | |||
| _pynative_executor.exit_cell() | |||
| if _pynative_executor.is_top_cell(): | |||
| _pynative_executor.set_lazy_build(False) | |||
| def auto_cast_inputs(self, inputs): | |||
| """Auto cast inputs in mixed precision scenarios.""" | |||
| cast_inputs = inputs | |||
| @@ -800,12 +800,12 @@ class Shard(Shard_): | |||
| Shard_.__init__(self, 'Shard') | |||
| self.shard_fn = None | |||
| self.fn = None | |||
| self.in_axes = None | |||
| self.out_axes = None | |||
| self.in_strategy = None | |||
| self.out_strategy = None | |||
| self.device = None | |||
| self.level = None | |||
| def __call__(self, fn, in_axes, out_axes, device="Ascend", level=0): | |||
| def __call__(self, fn, in_strategy, out_strategy, device="Ascend", level=0): | |||
| if context.get_context("mode") != context.PYNATIVE_MODE or \ | |||
| context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]: | |||
| raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode") | |||
| @@ -815,30 +815,30 @@ class Shard(Shard_): | |||
| raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False") | |||
| if context.get_auto_parallel_context("search_mode") != "sharding_propagation": | |||
| raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'") | |||
| if not isinstance(in_axes, tuple): | |||
| raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}") | |||
| if not isinstance(out_axes, tuple): | |||
| raise TypeError(f"For 'Shard', the 'out_axes' should be a tuple, " | |||
| f"but got {type(out_axes).__name__}") | |||
| if not isinstance(in_strategy, tuple): | |||
| raise TypeError(f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}") | |||
| if not isinstance(out_strategy, tuple): | |||
| raise TypeError(f"For 'Shard', the 'out_strategy' should be a tuple, " | |||
| f"but got {type(out_strategy).__name__}") | |||
| if not isinstance(device, str): | |||
| raise TypeError(f"For 'Shard', the 'device' should be a string, " | |||
| f"but got {type(device).__name__}") | |||
| if not isinstance(level, int): | |||
| raise TypeError(f"For 'Shard', the 'level' should be an integer, " | |||
| f"but got {type(level).__name__}") | |||
| if self.shard_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes and \ | |||
| self.device == device and self.level == level: | |||
| if self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \ | |||
| self.out_strategy == out_strategy and self.device == device and self.level == level: | |||
| return self.shard_fn | |||
| shard_ = Shard() | |||
| @ms_function(obj=fn) | |||
| def after_shard(*args): | |||
| return shard_(fn, in_axes, out_axes, device, level)(*args) | |||
| return shard_(fn, in_strategy, out_strategy, device, level)(*args) | |||
| self.shard_fn = after_shard | |||
| self.fn = fn | |||
| self.in_axes = in_axes | |||
| self.out_axes = out_axes | |||
| self.in_strategy = in_strategy | |||
| self.out_strategy = out_strategy | |||
| self.device = device | |||
| self.level = level | |||
| return self.shard_fn | |||
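The assertions above pin down the only configuration under which Shard can run: PyNative mode, the auto_parallel parallel mode with sharding_propagation as the search mode, and full_batch disabled. A minimal setup sketch under those constraints is shown below; the device_num value is an illustrative assumption.

from mindspore import context

# Settings required by Shard.__call__; anything else trips the assertions above.
context.set_context(mode=context.PYNATIVE_MODE)
context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                  search_mode="sharding_propagation",
                                  full_batch=False,
                                  device_num=8)   # illustrative device count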
| @@ -582,8 +582,8 @@ def vjp(fn, inputs, v): | |||
| shard_fn = Shard() | |||
| def shard(fn, in_axes, out_axes, device="Ascend", level=0): | |||
| return shard_fn(fn, in_axes, out_axes, device, level) | |||
| def shard(fn, in_strategy, out_strategy, device="Ascend", level=0): | |||
| return shard_fn(fn, in_strategy, out_strategy, device, level) | |||
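The functional form mirrors Cell.shard with the same renamed keywords. A small usage sketch follows, assuming the module is imported as F as in the ResNet example below; the import path and the concrete splits are illustrative only.

import mindspore.nn as nn
from mindspore.ops import functional as F   # import path assumed

dense = nn.Dense(2048, 1000)
# Split the second input dimension across 8 devices; keep the output data-parallel.
sharded_dense = F.shard(dense, in_strategy=((1, 8),), out_strategy=(None,))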
| def arange(start=0, stop=None, step=1, rtype=None): | |||
| @@ -183,7 +183,7 @@ class ResNet(nn.Cell): | |||
| in_channel=in_channels[0], | |||
| out_channel=out_channels[0], | |||
| stride=strides[0]) | |||
| self.layer1.shard(in_axes=(None,), out_axes=(None,)) | |||
| self.layer1.shard(in_strategy=(None,), out_strategy=(None,)) | |||
| self.layer2 = self._make_layer(block, | |||
| layer_nums[1], | |||
| in_channel=in_channels[1], | |||
| @@ -194,7 +194,7 @@ class ResNet(nn.Cell): | |||
| in_channel=in_channels[2], | |||
| out_channel=out_channels[2], | |||
| stride=strides[2]) | |||
| self.layer3.shard(in_axes=((8, 1, 1, 1),), out_axes=(None,)) | |||
| self.layer3.shard(in_strategy=((8, 1, 1, 1),), out_strategy=(None,)) | |||
| self.layer4 = self._make_layer(block, | |||
| layer_nums[3], | |||
| in_channel=in_channels[3], | |||
| @@ -205,7 +205,7 @@ class ResNet(nn.Cell): | |||
| self.end_point = nn.Dense(2048, num_classes, has_bias=True, | |||
| weight_init=weight_variable(), | |||
| bias_init=weight_variable()).add_flags_recursive(fp16=True) | |||
| self.head = F.shard(self.end_point, in_axes=((1, 8),), out_axes=(None,)) | |||
| self.head = F.shard(self.end_point, in_strategy=((1, 8),), out_strategy=(None,)) | |||
| self.squeeze = P.Squeeze() | |||
| self.cast = P.Cast() | |||
| @@ -376,7 +376,7 @@ def test_train_feed(num_classes=65536): | |||
| dataset = ds.GeneratorDataset(dataset, column_names=["image", "label"]) | |||
| net = resnet50(num_classes) | |||
| loss = SoftmaxCrossEntropyExpand(sparse=True) | |||
| loss.shard(in_axes=(None, None), out_axes=(None,)) | |||
| loss.shard(in_strategy=(None, None), out_strategy=(None,)) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) | |||
| model = Model(net, loss_fn=loss, optimizer=opt) | |||
| model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback) | |||
| @@ -42,21 +42,21 @@ class NetMatMul(nn.Cell): | |||
| return self.matmul(x, y) | |||
| class Net(nn.Cell): | |||
| def __init__(self, in_axes, out_axes): | |||
| def __init__(self, in_strategy, out_strategy): | |||
| super().__init__() | |||
| self.mul_net = NetMul() | |||
| self.matmul_net = NetMatMul() | |||
| self.mul_net.shard(in_axes=in_axes, out_axes=out_axes) | |||
| self.mul_net.shard(in_strategy=in_strategy, out_strategy=out_strategy) | |||
| def construct(self, x, y): | |||
| out1 = self.matmul_net(x, y) | |||
| out2 = self.matmul_net(x, y) | |||
| return self.mul_net(out1, out2) | |||
| def cell_shard_execution(in_axes, out_axes, error_log): | |||
| net = Net(in_axes, out_axes) | |||
| def cell_shard_execution(in_strategy, out_strategy, error_log): | |||
| net = Net(in_strategy, out_strategy) | |||
| x = Tensor(np.ones([128, 128]), dtype=ms.float32) | |||
| y = Tensor(np.ones([128, 128]), dtype=ms.float32) | |||
| @@ -65,63 +65,66 @@ def cell_shard_execution(in_axes, out_axes, error_log): | |||
| assert error_log in str(err.value) | |||
| def test_in_axes_numbers_check(): | |||
| def test_in_strategy_numbers_check(): | |||
| """ | |||
| Feature: shard function for cell | |||
| Description: inconsistent input number and in_axes number | |||
| Expectation: throw an exception indicating inconsistent input number and in_axes number | |||
| Description: inconsistent input number and in_strategy number | |||
| Expectation: throw an exception indicating inconsistent input number and in_strategy number | |||
| """ | |||
| set_context() | |||
| in_axes = ((8, 1), None, (1, 8)) | |||
| out_axes = (None,) | |||
| error_log = "Input numbers: 2 is not equal to in_axes numbers: 3" | |||
| cell_shard_execution(in_axes, out_axes, error_log) | |||
| in_strategy = ((8, 1), None, (1, 8)) | |||
| out_strategy = (None,) | |||
| error_log = "Input numbers: 2 is not equal to in_strategy numbers: 3" | |||
| cell_shard_execution(in_strategy, out_strategy, error_log) | |||
| def test_out_axes_numbers_check(): | |||
| def test_out_strategy_numbers_check(): | |||
| """ | |||
| Feature: shard function for cell | |||
| Description: inconsistent output number and out_axes number | |||
| Expectation: throw an exception indicating inconsistent output number and out_axes number | |||
| Description: inconsistent output number and out_strategy number | |||
| Expectation: throw an exception indicating inconsistent output number and out_strategy number | |||
| """ | |||
| set_context() | |||
| in_axes = ((8, 1), None) | |||
| out_axes = (None, (8, 1)) | |||
| error_log = "Output number: 1 is not equal to out_axes number: 2" | |||
| cell_shard_execution(in_axes, out_axes, error_log) | |||
| in_strategy = ((8, 1), None) | |||
| out_strategy = (None, (8, 1)) | |||
| error_log = "Output number: 1 is not equal to out_strategy number: 2" | |||
| cell_shard_execution(in_strategy, out_strategy, error_log) | |||
| def test_in_axes_dimension_check(): | |||
| def test_in_strategy_dimension_check(): | |||
| """ | |||
| Feature: shard function for cell | |||
| Description: inconsistent input dimension and in_axes dimension | |||
| Expectation: throw an exception indicating inconsistent input_dimension and in_axes dimension | |||
| Description: inconsistent input dimension and in_strategy dimension | |||
| Expectation: throw an exception indicating inconsistent input_dimension and in_strategy dimension | |||
| """ | |||
| set_context() | |||
| in_axes = ((8, 1, 1), None) | |||
| out_axes = (None, (8, 1)) | |||
| error_log = "Input dimension: 2 is not equal to in_axes dimension: 3 at index 0" | |||
| cell_shard_execution(in_axes, out_axes, error_log) | |||
| in_strategy = ((8, 1, 1), None) | |||
| out_strategy = (None, (8, 1)) | |||
| error_log = "Input dimension: 2 is not equal to in_strategy dimension: 3 at index 0" | |||
| cell_shard_execution(in_strategy, out_strategy, error_log) | |||
| def test_out_axes_dimension_check(): | |||
| def test_out_strategy_dimension_check(): | |||
| """ | |||
| Feature: shard function for cell | |||
| Description: inconsistent output dimension and out_axes dimension | |||
| Expectation: throw an exception indicating inconsistent output_dimension and out_axes dimension | |||
| Description: inconsistent output dimension and out_strategy dimension | |||
| Expectation: throw an exception indicating inconsistent output_dimension and out_strategy dimension | |||
| """ | |||
| set_context() | |||
| in_axes = ((8, 1), None) | |||
| out_axes = ((8,),) | |||
| error_log = "Output dimension: 2 is not equal to out_axes dimension: 1 at index 0" | |||
| cell_shard_execution(in_axes, out_axes, error_log) | |||
| in_strategy = ((8, 1), None) | |||
| out_strategy = ((8,),) | |||
| error_log = "Output dimension: 2 is not equal to out_strategy dimension: 1 at index 0" | |||
| cell_shard_execution(in_strategy, out_strategy, error_log) | |||
| def test_in_axes_format_check(): | |||
| def test_in_strategy_format_check(): | |||
| """ | |||
| Feature: shard function for cell | |||
| Description: unsupported in_axes format | |||
| Expectation: throw an exception indicating an supported in_axes format | |||
| Description: unsupported in_strategy format | |||
| Expectation: throw an exception indicating an unsupported in_strategy format | |||
| """ | |||
| set_context() | |||
| in_axes = ([8, 1], None) | |||
| out_axes = (None,) | |||
| error_log = "in_axes should be a two-dimension tuple" | |||
| cell_shard_execution(in_axes, out_axes, error_log) | |||
| in_strategy = ([8, 1], None) | |||
| out_strategy = (None,) | |||
| error_log = "in_strategy should be a two-dimension tuple" | |||
| cell_shard_execution(in_strategy, out_strategy, error_log) | |||
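A natural companion check, not included in this patch, would exercise the device-number validation in SetInputLayout (the "check device number failed" exception above). The sketch below follows the pattern of the existing tests; the oversized splits and the assumption that set_context() configures a small device count are illustrative.

def test_in_strategy_device_num_check():
    """
    Feature: shard function for cell
    Description: in_strategy whose split product exceeds the available device number
    Expectation: throw an exception indicating the device number check failed
    """
    set_context()
    in_strategy = ((16, 16), (16, 16))
    out_strategy = (None,)
    error_log = "check device number failed"
    cell_shard_execution(in_strategy, out_strategy, error_log)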