@@ -44,7 +44,14 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
       auto device_arrangement = tensor_layout->device_arrangement().array();
       auto tensor_map = tensor_layout->tensor_map().array();
       auto slice_shape = tensor_layout->slice_shape().array();
-      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
+      int32_t _field_size = tensor_layout->get_field_size();
+      std::vector<int32_t> field_size;
+      if (_field_size != 0) {
+        field_size.push_back(_field_size);
+      } else {
+        field_size = {0};
+      }
+      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape, field_size};
       dict[py::str(name)] = layout;
       MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
     }
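For reference, a sketch of what a `parameter_layout_dict` entry looks like once the field size is appended as the fourth element; the values are taken from the updated `test_get_parameter_layout` expectations further down, and the variable names are only illustrative:

```python
# Illustrative only: one entry of net.parameter_layout_dict after this change.
layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement, tensor_map, slice_shape, field_size
device_arrangement, tensor_map, slice_shape, field_size = layout
```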
@@ -105,6 +105,17 @@ Status MatMulBase::GetAttrs() {
     }
   }
 
+  auto field_size_iter = attrs_.find(FIELD_SIZE);
+  if (field_size_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(field_size_iter->second);
+    if (field_size_iter->second->isa<Int32Imm>()) {
+      field_size_ = field_size_iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of field_size is not int.";
+      return FAILED;
+    }
+  }
+
   // infer inputs dimension size
   if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
     MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
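The lookup key is FIELD_SIZE ("field_size"), so the value presumably reaches GetAttrs() as a primitive attribute set from Python. A hedged sketch, by analogy with the skip_redistribution attribute used in the new test case at the end of this change (the value 39 is just the example from the docstring added below, not something this diff sets anywhere):

```python
# Assumed usage, not part of this diff: attach field_size to a MatMul primitive
# so that MatMulBase::GetAttrs() can read it as an Int32 attribute.
from mindspore.ops import operations as P

matmul = P.MatMul(transpose_b=False).add_prim_attr("field_size", 39)
```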
@@ -346,6 +357,10 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
     return FAILED;
   }
 
+  if (field_size_ != 0) {
+    mat_b_layout.set_field_size(field_size_);
+  }
+
   inputs_layout->push_back(mat_a_layout);
   inputs_layout->push_back(mat_b_layout);
   outputs_layout->push_back(output_layout);
@@ -62,6 +62,7 @@ class MatMulBase : public OperatorInfo {
   bool transpose_a_ = false;
   bool transpose_b_ = false;
   bool forward_reduce_scatter_ = false;
+  int32_t field_size_ = 0;
   size_t mat_a_dimension_ = 0;
   size_t mat_b_dimension_ = 0;
 };
@@ -100,6 +100,7 @@ constexpr char CONCAT_DIM[] = "concat_dim";
 constexpr char FORWARD[] = "forward";
 constexpr char BACKWARD[] = "backward";
 constexpr char REDISTRIBUTION[] = "redistribution";
+constexpr char SKIP_REDISTRIBUTION[] = "skip_redistribution";
 constexpr char REPLACE[] = "replace";
 constexpr char CONNSYMBOL[] = "/";
 constexpr char INSTANCE_NAME[] = "instance_name";
@@ -131,6 +132,7 @@ constexpr char FORWARD_OP[] = "forward_op";
 constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
 constexpr char DARA_PARALLEL[] = "data_parallel";
 constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
+constexpr char FIELD_SIZE[] = "field_size";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
 constexpr char DEVICE[] = "Device";
@@ -18,6 +18,7 @@
 #include <memory>
 #include <vector>
+#include <utility>
 
 #include "frontend/parallel/device_manager.h"
 #include "frontend/parallel/device_matrix.h"
@@ -145,17 +146,23 @@ Status ReshapeInfo::ComputeReplaceOp() {
   MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
   MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
   MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
-  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
-  if (redistribution_oplist_ptr == nullptr) {
-    if (is_generating_costs_) {
-      MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
+  if (is_skip_) {
+    ConstructOperator constructor;
+    replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
+    replace_op_info_.clear();
+  } else {
+    RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
+    if (redistribution_oplist_ptr == nullptr) {
+      if (is_generating_costs_) {
+        MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
+      } else {
+        MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
+      }
+      return FAILED;
     }
-    return FAILED;
+    replace_op_ = redistribution_oplist_ptr->first;
+    replace_op_info_ = redistribution_oplist_ptr->second;
   }
-  replace_op_ = redistribution_oplist_ptr->first;
-  replace_op_info_ = redistribution_oplist_ptr->second;
   MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
   return SUCCESS;
 }
@@ -255,6 +262,19 @@ Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayout
 }
 
 Status ReshapeInfo::InferTensorInfo() {
+  // skip reshape infer if skip_redistribution is true
+  if (is_skip_) {
+    TensorLayout layout;
+    Shape shape;
+    Shape slice_shape;
+    layout.set_skip_redistribution(true);
+    TensorInfo tensor_info_in(layout, shape, slice_shape);
+    inputs_tensor_info_.push_back(tensor_info_in);
+    outputs_tensor_info_.push_back(tensor_info_in);
+    MS_LOG(DEBUG) << name() << "skip redistribution reshape InferTensorInfo";
+    return SUCCESS;
+  }
+
   Shapes inputs_slice_shape, outputs_slice_shape;
   Strategys inputs_strategy = strategy_->GetInputDim();
   Strategys outputs_strategy = GetOutputsStrategy();
@@ -316,6 +336,16 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l
 }
 
 Status ReshapeInfo::Init(const StrategyPtr &strategy) {
+  auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
+  if (reshape_skip_redis_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
+    if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
+      MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
+      return FAILED;
+    }
+    is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
+  }
+
   ResetQueueMember();
   device_number(strategy);
   if (strategy) {
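On the Python side the flag is a plain bool primitive attribute; the new unit test at the end of this change drives this code path like so:

```python
# From the test case added below: marking a Reshape primitive so that
# ReshapeInfo::Init() finds SKIP_REDISTRIBUTION in attrs_ and sets is_skip_.
from mindspore.ops import operations as P

reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
```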
@@ -98,6 +98,7 @@ class ReshapeInfo : public OperatorInfo {
   bool input_layout_set_flag_;
   bool output_layout_set_flag_;
   bool is_generating_costs_;
+  bool is_skip_ = false;
   std::string pre_operator_name_;
   std::string next_operator_name_;
 };
@@ -302,16 +302,26 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
   // extract tensor layout in and out
   if (distribute_operator->outputs_tensor_info().empty()) {
-    MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty";
+    MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
+    return;
   }
   if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
-    MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
-                      << next_distribute_operator->inputs_tensor_info().size();
+    MS_LOG(WARNING) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
+                    << next_distribute_operator->inputs_tensor_info().size() << "next operator name is "
+                    << next_distribute_operator->name();
+    return;
   }
   TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)];
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
+  if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
+    MS_LOG(INFO) << "skip the reshape redistribution, operator name is" << distribute_operator->name()
+                 << "next distribute operator, operator name is" << next_distribute_operator->name();
+    return;
+  }
   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@@ -28,6 +28,19 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
   return Status::SUCCESS;
 }
 
+// skip redistribution for reshape operator
+OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
+  OperatorAttrs attrs;
+  ValuePtr param_value = MakeValue(shape);
+  Attr param = std::make_pair(SHAPE, param_value);
+  OperatorParams params = {std::make_pair(param, 2)};
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(RESHAPE, args);
+  OperatorVector opvector;
+  opvector.push_back(op);
+  return opvector;
+}
+
 Status ConstructOperator::ReshapeOP(Shape shape) {
   int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
   int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int>());
@@ -35,6 +35,7 @@ class ConstructOperator {
   ConstructOperator() : dev_size_(0) {}
   ~ConstructOperator() = default;
   Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
+  OperatorVector SkipRedisReshapeOP(Shape shape);
   Status ReshapeOP(Shape shape);
   Status StridedSliceOP(Args args);
   Status AllGatherOP(int32_t dev_dim);
@@ -41,6 +41,14 @@ class TensorLayout {
   Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map,
                         const std::vector<int32_t> &tensor_shape);
 
+  bool skip_redistribution() const { return skip_redistribution_; }
+
+  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
+
+  int32_t get_field_size() const { return field_size_; }
+
+  void set_field_size(int32_t field_size) { field_size_ = field_size; }
+
   Arrangement device_arrangement() const { return device_arrangement_; }
 
   Map tensor_map() const { return tensor_map_; }
@@ -92,6 +100,8 @@ class TensorLayout {
   Arrangement device_arrangement_;
   Map tensor_map_;
   Arrangement tensor_shape_;
+  bool skip_redistribution_ = false;
+  int32_t field_size_ = 0;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -247,8 +247,8 @@ class Parameter:
             if not isinstance(layout, list):
                 raise TypeError("The layout should be list! layout is {}."
                                 .format(layout))
-            if len(layout) != 3:
-                raise ValueError("The length of layout must be 3! layout is {}."
+            if len(layout) < 3:
+                raise ValueError("The length of layout must be larger than 3! layout is {}."
                                  .format(layout))
             slice_index = int(_get_slice_index(layout[0], layout[1]))
             self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
@@ -229,8 +229,8 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) != 3:
-        raise ValueError("The length of layout must be 3! layout is {}".format(layout))
+    if len(layout) < 3:
+        raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
     if tensor.size() == 1:
@@ -290,3 +290,37 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
         tensor_slices_new = tensor_slices_new_inner
 
     return Tensor(tensor_slices_new[0])
+
+
+def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
+    """
+    Combine param slice by the device matrix, used in model parallel scenario.
+
+    Args:
+        param_data (Tensor): The tensor to be reshaped and rearrangement,
+        generated from all the device from AllGatherParamNet.
+        dev_mat (list): The device matrix of devices.
+    Returns:
+        Tensor, the combined tensor which with the whole data value.
+
+    Examples:
+        >>> param_data = _allgather_param_net(param_data)
+        >>> dev_mat = [2, 2]
+        >>> field_size = [39]
+        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
+    """
+    device_count = 1
+    for dim in dev_mat:
+        device_count *= dim
+
+    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
+    tensor_slices_col = []
+    for i in range(len(tensor_slices[0][0])):
+        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1)
+        for j in range(1, device_count):
+            tensor_slices_new = np.concatenate((tensor_slices_new,\
+                                np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1)
+        tensor_slices_col.append(tensor_slices_new)
+    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
+    for i in range(1, len(tensor_slices_col)):
+        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
+    return Tensor(new_tensor)
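A rough standalone sketch of the merge loop above, in plain numpy with made-up toy shapes (the helper name and the shapes are illustrative, not from the patch): each device's allgathered slice is walked column by column, regrouped per field, and the per-field pieces are stitched back into the full table.

```python
import numpy as np

def merge_with_field_size(gathered, device_count, field_size):
    # gathered: allgathered parameter, one device slice stacked after another on axis 0
    tensor_slices = np.split(gathered, device_count, axis=0)
    per_column = []
    for i in range(tensor_slices[0].shape[1]):  # one pass per embedding column
        col = tensor_slices[0][:, i].reshape(field_size, -1)
        for j in range(1, device_count):
            # append device j's piece of this column, field by field
            col = np.concatenate((col, tensor_slices[j][:, i].reshape(field_size, -1)), axis=1)
        per_column.append(col)
    # flatten each per-column block back into one column of the merged table
    merged = per_column[0].reshape(-1, 1)
    for i in range(1, len(per_column)):
        merged = np.concatenate((merged, per_column[i].reshape(-1, 1)), axis=1)
    return merged

# toy check: 2 devices, 3 fields, each device slice is (6, 4); the merged table is (12, 4) again
gathered = np.arange(2 * 6 * 4, dtype=np.float32).reshape(12, 4)
print(merge_with_field_size(gathered, device_count=2, field_size=3).shape)  # (12, 4)
```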
@@ -359,14 +359,17 @@ def _get_merged_param_data(net, param_name, param_data):
     dev_mat = layout[0]
     tensor_map = layout[1]
+    field_size = layout[3]
 
     from mindspore.parallel._cell_wrapper import get_allgather_cell
-    from mindspore.parallel._tensor import _reshape_param_data
+    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
     # while any dim is not equal to -1, means param is splited and needs to be merged
     for dim in tensor_map:
         if dim != -1:
             allgather_net = get_allgather_cell()
             param_data = allgather_net(param_data)
+            if field_size[0]:
+                return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
             return _reshape_param_data(param_data, dev_mat, tensor_map)
 
     return param_data
@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict
@@ -0,0 +1,58 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, matmul_weight, strategy1=None):
+        super().__init__()
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
+        self.matmul = P.MatMul(transpose_b=False)
+        self.index = Tensor(np.ones([64, 64]), dtype=ms.int32)
+        self.matmul_weight = Parameter(matmul_weight, "w1")
+        self.axis = 0
+
+    def construct(self, x, b):
+        out = self.gatherv2(x, self.index, self.axis)
+        out = self.reshape(out, (64, -1))
+        out = self.matmul(out, self.matmul_weight)
+        return out
+
+
+_w1 = Tensor(np.ones([4096, 32]), dtype=ms.float32)
+_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    context.set_context(save_graphs=True)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_reshape_skip_redistribution():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 8), (1, 1))
+    net = Net(_w1, strategy1)
+    compile_net(net)