| @@ -450,7 +450,7 @@ | |||
| 其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: `mindspore.ops.Primitive.shard` 的描述,也可以设置为None,会默认以数据并行执行。 | |||
| 其余算子的并行策略由输入输出指定的策略推导得到。 | |||
| .. note:: 需设置为PyNative模式,并且全自动并行(AUTO_PARALLEL),同时设置 `set_auto_parallel_context` 中的搜索模式(search mode)为"sharding_propagation",或半自动并行(SEMI_AUTO_PARALLEL)。 | |||
| .. note:: 需设置为PyNative模式,并且全自动并行(AUTO_PARALLEL),同时设置 `set_auto_parallel_context` 中的搜索模式(search mode)为"sharding_propagation"。 | |||
| **参数:** | |||
| @@ -1236,6 +1236,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| new_cnode->set_parallel(cnode->is_parallel()); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) { | |||
| new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope()); | |||
| } | |||
| @@ -1374,6 +1375,11 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(parameter_index); | |||
| size_t index = 0; | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| auto parallel_mode = parallel_context->parallel_mode(); | |||
| bool is_parallel_forward_ms_function = | |||
| !graph->is_bprop() && (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel); | |||
| for (const auto &input_node : graph->input_nodes()) { | |||
| auto params = common::AnfAlgo::GetAllOutput(input_node); | |||
| for (const auto ¶m : params) { | |||
| @@ -1386,13 +1392,14 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector | |||
| // Check shape of input and parameter | |||
| const auto &input_shape = input->shape(); | |||
| const auto ¶m_shape = common::AnfAlgo::GetOutputInferShape(param, 0); | |||
| if (input_shape.size() != param_shape.size()) { | |||
| if (!is_parallel_forward_ms_function && input_shape.size() != param_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "Shape size of input tensor(" << input_shape << ") and parameter(" << param_shape | |||
| << ") are different, input index: " << index << ", parameter: " << param->DebugString(); | |||
| } | |||
| bool is_dynamic = param->Shape()->IsDynamic(); | |||
| for (size_t i = 0; i < input_shape.size(); i += 1) { | |||
| if (input_shape[i] < 0 || (static_cast<size_t>(input_shape[i]) != param_shape[i] && !is_dynamic)) { | |||
| if (input_shape[i] < 0 || (!is_parallel_forward_ms_function && | |||
| static_cast<size_t>(input_shape[i]) != param_shape[i] && !is_dynamic)) { | |||
| MS_LOG(EXCEPTION) << "Input tensor shape(" << input_shape << ") and parameter shape(" << param_shape | |||
| << ") are different, input index: " << index << ", parameter: " << param->DebugString(); | |||
| } | |||
| @@ -2648,7 +2655,29 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const | |||
| if (grads_count == 0) { | |||
| MS_LOG(EXCEPTION) << "Bprop graph has no grad"; | |||
| } | |||
| return {grads_count}; | |||
| uint32_t remove_number = 0; | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| auto parallel_mode = parallel_context->parallel_mode(); | |||
| if (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel) { | |||
| auto ret = graph->get_return(); | |||
| auto current_node = ret->cast<CNodePtr>(); | |||
| while (IsPrimitiveCNode(current_node->input(1), prim::kPrimMakeTuple)) { | |||
| current_node = current_node->input(1)->cast<CNodePtr>(); | |||
| } | |||
| auto inputs = current_node->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto node = inputs[i]; | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->is_parallel()) { | |||
| remove_number += 1; | |||
| } | |||
| } | |||
| } | |||
| return {grads_count - remove_number}; | |||
| } | |||
| std::vector<uint32_t> bucket_size_list; | |||
| @@ -2705,7 +2734,8 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| auto parallel_mode = parallel_context->parallel_mode(); | |||
| if (!pynative_mode || parallel_mode != parallel::kDataParallel) { | |||
| if (!pynative_mode || (parallel_mode != parallel::kDataParallel && parallel_mode != parallel::kSemiAutoParallel && | |||
| parallel_mode != parallel::kAutoParallel)) { | |||
| return; | |||
| } | |||
| SetGraphBpropAttr(graph); | |||
| @@ -2747,7 +2777,8 @@ void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vecto | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| auto parallel_mode = parallel_context->parallel_mode(); | |||
| if (parallel_mode != parallel::kDataParallel) { | |||
| if (parallel_mode != parallel::kDataParallel && parallel_mode != parallel::kAutoParallel && | |||
| parallel_mode != parallel::kSemiAutoParallel) { | |||
| return; | |||
| } | |||
| @@ -32,6 +32,9 @@ | |||
| namespace mindspore { | |||
| namespace pipeline { | |||
| // Element-wise binary primitives whose in_strategy must be aligned across both
| // inputs (see HandleStrategyForElementWiseNode). Membership is checked by
| // primitive name in IsElementWiseNode.
| static const std::set<std::string> ELEMENT_WISE_NODE_ = {"Add", "BiasAdd", "ScalarAdd", "Sub", | |||
| "ScalarSub", "Mul", "ScalarMul", "RealDiv", | |||
| "ScalarDiv", "FloorDiv", "ScalarFloorDiv"}; | |||
| std::string GetWorldGroup() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| @@ -238,6 +241,12 @@ bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, size_t * | |||
| return true; | |||
| } | |||
| bool IsElementWiseNode(const CNodePtr &cnode) { | |||
| auto prim = GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return ELEMENT_WISE_NODE_.find(prim->name()) != ELEMENT_WISE_NODE_.end(); | |||
| } | |||
| void HandleStrategyForOneHot(std::vector<ValuePtr> *strategy) { | |||
| // onehot needs to set layout for output, modify the strategy with an additional dimension | |||
| auto input_strategy = GetValue<std::vector<int64_t>>(strategy->at(0)); | |||
| @@ -271,6 +280,47 @@ void HandleStrategyForMatMul(std::vector<ValuePtr> *strategy, const CNodePtr &cn | |||
| } | |||
| } | |||
| void HandleStrategyForElementWiseNode(std::vector<ValuePtr> *strategy, const CNodePtr &cnode) { | |||
| auto left_strategy = GetValue<std::vector<int64_t>>(strategy->at(0)); | |||
| auto right_strategy = GetValue<std::vector<int64_t>>(strategy->at(1)); | |||
| if (left_strategy.size() != right_strategy.size()) { | |||
| return; | |||
| } | |||
| int64_t strategy_mul = 1; | |||
| std::for_each(left_strategy.begin(), left_strategy.end(), [&](int64_t const &data) { strategy_mul *= data; }); | |||
| auto left_shape = cnode->input(1)->Shape()->cast<abstract::ShapePtr>(); | |||
| auto left_batch = left_shape->shape()[0]; | |||
| auto right_shape = cnode->input(2)->Shape()->cast<abstract::ShapePtr>(); | |||
| auto right_batch = right_shape->shape()[0]; | |||
| if (strategy_mul == 1) { | |||
| left_strategy = right_strategy; | |||
| } else { | |||
| right_strategy = left_strategy; | |||
| } | |||
| if (left_batch == 1) { | |||
| left_strategy[0] = 1; | |||
| } | |||
| if (right_batch == 1) { | |||
| right_strategy[0] = 1; | |||
| } | |||
| strategy->at(0) = MakeValue(left_strategy); | |||
| strategy->at(1) = MakeValue(right_strategy); | |||
| } | |||
| void HandleSpecialStrategy(std::vector<ValuePtr> *strategy, const CNodePtr &cnode) { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimMatMul) || IsPrimitiveCNode(cnode, prim::kPrimBatchMatMul)) { | |||
| HandleStrategyForMatMul(strategy, cnode); | |||
| } | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimOneHot)) { | |||
| HandleStrategyForOneHot(strategy); | |||
| } | |||
| if (IsElementWiseNode(cnode)) { | |||
| HandleStrategyForElementWiseNode(strategy, cnode); | |||
| } | |||
| } | |||
| void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *input_nodes) { | |||
| auto parameters = func_graph->parameters(); | |||
| for (auto ¶meter : parameters) { | |||
| @@ -352,7 +402,6 @@ void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, | |||
| auto attrs_temp = prim->attrs(); | |||
| ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements); | |||
| attrs_temp[parallel::OUT_STRATEGY] = strategy; | |||
| (void)prim->SetAttrs(attrs_temp); | |||
| } | |||
| } | |||
| @@ -424,12 +473,9 @@ void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, c | |||
| } | |||
| for (auto &cnode : concerned_nodes) { | |||
| auto elements = GetStrategyElements(cnode, parameters, input_strategy); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimMatMul) || IsPrimitiveCNode(cnode, prim::kPrimBatchMatMul)) { | |||
| HandleStrategyForMatMul(&elements, cnode); | |||
| } | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimOneHot)) { | |||
| HandleStrategyForOneHot(&elements); | |||
| } | |||
| // Some operators has a special requirements for parallel strategy | |||
| HandleSpecialStrategy(&elements, cnode); | |||
| // Set in_strategy | |||
| ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements); | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| auto attrs_temp = prim->attrs(); | |||
| @@ -2712,7 +2712,11 @@ void GradExecutor::MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) { | |||
| auto grads = ret_cnode->input(1)->cast<CNodePtr>(); | |||
| for (size_t i = 1; i < grads->inputs().size(); i++) { | |||
| if (in_ms_function[i - 1]) { | |||
| auto cnode = grads->input(i)->cast<CNodePtr>(); | |||
| auto node = grads->input(i); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| cnode->set_parallel(true); | |||
| } | |||
| } | |||
| @@ -35,6 +35,7 @@ from .tensor import COOTensor as MsCOOTensor | |||
| from .initializer import initializer | |||
| from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, COOTensor, PynativeExecutor_ | |||
| from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline | |||
| from ..parallel._tensor import _load_tensor_by_layout | |||
| from ..parallel._ps_context import _is_role_pserver, _is_role_sched | |||
| from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \ | |||
| _get_parameter_broadcast, _get_pipeline_stages | |||
| @@ -263,7 +264,12 @@ class _MindsporeFunctionExecutor: | |||
| for states in states_tuple: | |||
| for param, state in zip(params, states): | |||
| if param.shape != state.shape: | |||
| state.set_data(initializer(state.init, param.shape), True) | |||
| if state.has_init: | |||
| state.set_data(initializer("zeros", param.shape), True) | |||
| else: | |||
| layout = obj.parameter_layout_dict[param.name] | |||
| new_tensor = _load_tensor_by_layout(state.data, layout) | |||
| state.set_data(new_tensor, True) | |||
| _pynative_executor.get_top_cell().parameter_layout_dict = obj.parameter_layout_dict | |||
| def compile(self, args_list, method_name): | |||
| @@ -488,7 +488,7 @@ class Cell(Cell_): | |||
| Note: | |||
| Only effective in PYNATIVE_MODE and in either ParallelMode.AUTO_PARALLEL with | |||
| search_mode in auto_parallel_context set as sharding_propagation or ParallelMode.SEMI_AUTO_PARALLEL. | |||
| search_mode in auto_parallel_context set as sharding_propagation. | |||
| Args: | |||
| in_axes (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple | |||
| @@ -507,23 +507,23 @@ class Cell(Cell_): | |||
| >>> import mindspore.nn as nn | |||
| >>> | |||
| >>> class Block(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> self.dense1 = nn.Dense(10, 10) | |||
| >>> self.relu = nn.ReLU() | |||
| >>> self.dense2 = nn.Dense2(10, 10) | |||
| >>> def construct(self, x): | |||
| >>> x = self.relu(self.dense2(self.relu(self.dense1(x)))) | |||
| >>> return x | |||
| ... def __init__(self): | |||
| ... self.dense1 = nn.Dense(10, 10) | |||
| ... self.relu = nn.ReLU() | |||
| ... self.dense2 = nn.Dense2(10, 10) | |||
| ... def construct(self, x): | |||
| ... x = self.relu(self.dense2(self.relu(self.dense1(x)))) | |||
| ... return x | |||
| >>> | |||
| >>> class example(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> self.block1 = Block() | |||
| >>> self.block2 = Block() | |||
| >>> self.block2.shard(in_axes=(None, (2, 1)), out_axes=(None,)) | |||
| >>> def construct(self, x): | |||
| >>> x = self.block1(x) | |||
| >>> x = self.block2(x) | |||
| >>> return x | |||
| ... def __init__(self): | |||
| ... self.block1 = Block() | |||
| ... self.block2 = Block() | |||
| ... self.block2.shard(in_axes=((2, 1),), out_axes=(None,)) | |||
| ... def construct(self, x): | |||
| ... x = self.block1(x) | |||
| ... x = self.block2(x) | |||
| ... return x | |||
| """ | |||
| shard_fn = Shard() | |||
| fn = shard_fn(self, in_axes, out_axes, device, level) | |||
| @@ -803,8 +803,14 @@ class Shard(Shard_): | |||
| def __call__(self, fn, in_axes, out_axes, device="Ascend", level=0): | |||
| if context.get_context("mode") != context.PYNATIVE_MODE or \ | |||
| context.get_auto_parallel_context("parallel_mode") not in ["semi_auto_parallel", "auto_parallel"]: | |||
| raise AssertionError(f"'Shard' only supports semi_auto/auto parallel under PyNative mode") | |||
| context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]: | |||
| raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode") | |||
| if context.get_context("device_target") not in ["Ascend"]: | |||
| raise AssertionError(f"'Shard' now only supports 'Ascend'") | |||
| if context.get_auto_parallel_context("full_batch"): | |||
| raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False") | |||
| if context.get_auto_parallel_context("search_mode") != "sharding_propagation": | |||
| raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'") | |||
| if not isinstance(in_axes, tuple): | |||
| raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}") | |||
| if not isinstance(out_axes, tuple): | |||
| @@ -355,6 +355,13 @@ class _AutoParallelContext: | |||
| ValueError: If parallel mode is not supported. | |||
| """ | |||
| self.check_context_handle() | |||
| run_mode = context.get_context("mode") | |||
| if run_mode == context.PYNATIVE_MODE and parallel_mode not in ( | |||
| context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE, | |||
| context.ParallelMode.AUTO_PARALLEL): | |||
| raise ValueError(f"Pynative Only support STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL under shard function" | |||
| f"for ParallelMode, " | |||
| f"but got {parallel_mode.upper()}.") | |||
| ret = self._context_handle.set_parallel_mode(parallel_mode) | |||
| if ret is False: | |||
| raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', " | |||
| @@ -300,7 +300,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, | |||
| param_list = [] | |||
| for (key, value) in param_dict.items(): | |||
| each_param = {"name": key} | |||
| param_data = Tensor(value.data) | |||
| param_data = Tensor(value.data.asnumpy()) | |||
| # in automatic model parallel scenario, some parameters were split to all the devices, | |||
| # which should be combined before saving | |||
| @@ -422,7 +422,7 @@ def load(file_name, **kwargs): | |||
| dec_key = Validator.check_isinstance('dec_key', kwargs['dec_key'], bytes) | |||
| dec_mode = 'AES-GCM' | |||
| if 'dec_mode' in kwargs.keys(): | |||
| dec_mode = Validator.check_isinstance('dec_mode', kwargs['dec_mode'], str) | |||
| dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str) | |||
| graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode) | |||
| else: | |||
| graph = load_mindir(file_name) | |||