@@ -44,7 +44,14 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
       auto device_arrangement = tensor_layout->device_arrangement().array();
       auto tensor_map = tensor_layout->tensor_map().array();
       auto slice_shape = tensor_layout->slice_shape().array();
-      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
+      int32_t _field_size = tensor_layout->get_field_size();
+      std::vector<int32_t> field_size;
+      if (_field_size != 0) {
+        field_size.push_back(_field_size);
+      } else {
+        field_size = {0};
+      }
+      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape, field_size};
       dict[py::str(name)] = layout;
       MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
     }
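With this change, every entry in `parameter_layout_dict` carries four fields instead of three. A minimal sketch of what a consumer now receives; the concrete numbers are illustrative only (the updated UT near the end of this diff expects `[0]` when no field_size attribute is set):

```python
# Illustrative only: GetParameterLayout now returns
# [device_arrangement, tensor_map, slice_shape, field_size] per parameter.
layout = [[2, 4], [0, -1], [16, 32], [39]]

device_arrangement, tensor_map, slice_shape, field_size = layout
assert field_size == [39]  # [0] for a parameter without a field_size attribute
```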
@@ -105,6 +105,17 @@ Status MatMulBase::GetAttrs() {
    }
  }

  auto field_size_iter = attrs_.find(FIELD_SIZE);
  if (field_size_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(field_size_iter->second);
    if (field_size_iter->second->isa<Int32Imm>()) {
      field_size_ = field_size_iter->second->cast<Int32ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of field_size is not int.";
      return FAILED;
    }
  }

  // infer inputs dimension size
  if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
@@ -346,6 +357,10 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
    return FAILED;
  }

  if (field_size_ != 0) {
    mat_b_layout.set_field_size(field_size_);
  }

  inputs_layout->push_back(mat_a_layout);
  inputs_layout->push_back(mat_b_layout);
  outputs_layout->push_back(output_layout);
@@ -62,6 +62,7 @@ class MatMulBase : public OperatorInfo {
  bool transpose_a_ = false;
  bool transpose_b_ = false;
  bool forward_reduce_scatter_ = false;
  int32_t field_size_ = 0;
  size_t mat_a_dimension_ = 0;
  size_t mat_b_dimension_ = 0;
};
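The `field_size_` member read in `GetAttrs()` above comes from the primitive's attribute map, so on the Python side it would presumably be attached with `add_prim_attr`, the same mechanism the new UT below uses for `skip_redistribution`. A hedged sketch, with 39 as an illustrative value:

```python
# Hedged sketch: attach a "field_size" attribute to a MatMul primitive so that
# MatMulBase::GetAttrs() can pick it up as an Int32Imm. The value 39 is illustrative.
from mindspore.ops import operations as P

matmul = P.MatMul(transpose_b=False).add_prim_attr("field_size", 39)
```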
@@ -100,6 +100,7 @@ constexpr char CONCAT_DIM[] = "concat_dim";
constexpr char FORWARD[] = "forward";
constexpr char BACKWARD[] = "backward";
constexpr char REDISTRIBUTION[] = "redistribution";
constexpr char SKIP_REDISTRIBUTION[] = "skip_redistribution";
constexpr char REPLACE[] = "replace";
constexpr char CONNSYMBOL[] = "/";
constexpr char INSTANCE_NAME[] = "instance_name";
@@ -131,6 +132,7 @@ constexpr char FORWARD_OP[] = "forward_op";
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
constexpr char DARA_PARALLEL[] = "data_parallel";
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
constexpr char FIELD_SIZE[] = "field_size";
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
constexpr char DEVICE[] = "Device";
@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <utility>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
@@ -145,17 +146,23 @@ Status ReshapeInfo::ComputeReplaceOp() {
   MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
   MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
   MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
-  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
-  if (redistribution_oplist_ptr == nullptr) {
-    if (is_generating_costs_) {
-      MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
+  if (is_skip_) {
+    ConstructOperator constructor;
+    replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
+    replace_op_info_.clear();
+  } else {
+    RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
+    if (redistribution_oplist_ptr == nullptr) {
+      if (is_generating_costs_) {
+        MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed.";
+      } else {
+        MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
+      }
+      return FAILED;
     }
-    return FAILED;
+    replace_op_ = redistribution_oplist_ptr->first;
+    replace_op_info_ = redistribution_oplist_ptr->second;
   }
-  replace_op_ = redistribution_oplist_ptr->first;
-  replace_op_info_ = redistribution_oplist_ptr->second;
   MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
   return SUCCESS;
 }
@@ -255,6 +262,19 @@ Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayout
}

Status ReshapeInfo::InferTensorInfo() {
  // skip reshape infer if skip_redistribution is true
  if (is_skip_) {
    TensorLayout layout;
    Shape shape;
    Shape slice_shape;
    layout.set_skip_redistribution(true);
    TensorInfo tensor_info_in(layout, shape, slice_shape);
    inputs_tensor_info_.push_back(tensor_info_in);
    outputs_tensor_info_.push_back(tensor_info_in);
    MS_LOG(DEBUG) << name() << ": skip redistribution reshape InferTensorInfo";
    return SUCCESS;
  }

  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Strategys outputs_strategy = GetOutputsStrategy();
@@ -316,6 +336,16 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l
}

Status ReshapeInfo::Init(const StrategyPtr &strategy) {
  auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
  if (reshape_skip_redis_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
    if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
      MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
      return FAILED;
    }
    is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
  }

  ResetQueueMember();
  device_number(strategy);
  if (strategy) {
@@ -98,6 +98,7 @@ class ReshapeInfo : public OperatorInfo {
  bool input_layout_set_flag_;
  bool output_layout_set_flag_;
  bool is_generating_costs_;
  bool is_skip_ = false;
  std::string pre_operator_name_;
  std::string next_operator_name_;
};
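For reference, the `SKIP_REDISTRIBUTION` attribute that `ReshapeInfo::Init` looks up above is an ordinary primitive attribute set from Python, exactly as the new UT at the end of this diff does:

```python
# Mark a Reshape primitive so that ReshapeInfo::Init sets is_skip_ and
# ComputeReplaceOp takes the SkipRedisReshapeOP path instead of inferring
# a redistribution operator list.
from mindspore.ops import operations as P

reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
```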
@@ -302,16 +302,26 @@ void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorI
   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
   // extract tensor layout in and out
   if (distribute_operator->outputs_tensor_info().empty()) {
-    MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty";
+    MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
+    return;
   }
   if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
-    MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
-                      << next_distribute_operator->inputs_tensor_info().size();
+    MS_LOG(WARNING) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
+                    << next_distribute_operator->inputs_tensor_info().size() << ", next operator name is "
+                    << next_distribute_operator->name();
+    return;
   }
   TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)];
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
+  if (tensorlayout_in.skip_redistribution() || tensorlayout_out.skip_redistribution()) {
+    MS_LOG(INFO) << "Skip the reshape redistribution, operator name is " << distribute_operator->name()
+                 << ", next operator name is " << next_distribute_operator->name();
+    return;
+  }
   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@@ -28,6 +28,19 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
  return Status::SUCCESS;
}

// skip redistribution for reshape operator
OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
  OperatorAttrs attrs;
  ValuePtr param_value = MakeValue(shape);
  Attr param = std::make_pair(SHAPE, param_value);
  OperatorParams params = {std::make_pair(param, 2)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(RESHAPE, args);
  OperatorVector opvector;
  opvector.push_back(op);
  return opvector;
}

Status ConstructOperator::ReshapeOP(Shape shape) {
  int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
  int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int>());
@@ -35,6 +35,7 @@ class ConstructOperator {
  ConstructOperator() : dev_size_(0) {}
  ~ConstructOperator() = default;
  Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
  OperatorVector SkipRedisReshapeOP(Shape shape);
  Status ReshapeOP(Shape shape);
  Status StridedSliceOP(Args args);
  Status AllGatherOP(int32_t dev_dim);
@@ -41,6 +41,14 @@ class TensorLayout {
  Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map,
                        const std::vector<int32_t> &tensor_shape);

  bool skip_redistribution() const { return skip_redistribution_; }

  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }

  int32_t get_field_size() const { return field_size_; }

  void set_field_size(int32_t field_size) { field_size_ = field_size; }

  Arrangement device_arrangement() const { return device_arrangement_; }

  Map tensor_map() const { return tensor_map_; }
@@ -92,6 +100,8 @@ class TensorLayout {
  Arrangement device_arrangement_;
  Map tensor_map_;
  Arrangement tensor_shape_;
  bool skip_redistribution_ = false;
  int32_t field_size_ = 0;
};
}  // namespace parallel
}  // namespace mindspore
@@ -247,8 +247,8 @@ class Parameter:
         if not isinstance(layout, list):
             raise TypeError("The layout should be list! layout is {}."
                             .format(layout))
-        if len(layout) != 3:
-            raise ValueError("The length of layout must be 3! layout is {}."
+        if len(layout) < 3:
+            raise ValueError("The length of layout must be no less than 3! layout is {}."
                             .format(layout))
         slice_index = int(_get_slice_index(layout[0], layout[1]))
         self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
@@ -229,8 +229,8 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) != 3:
-        raise ValueError("The length of layout must be 3! layout is {}".format(layout))
+    if len(layout) < 3:
+        raise ValueError("The length of layout must be no less than 3! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
     if tensor.size() == 1:
@@ -290,3 +290,37 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
        tensor_slices_new = tensor_slices_new_inner

    return Tensor(tensor_slices_new[0])


def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
    """
    Combine param slices by the device matrix, used in model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped and rearranged,
            gathered from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        field_size (list): The field size of the sliced parameter, e.g. [39].

    Returns:
        Tensor, the combined tensor which contains the whole data.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> field_size = [39]
        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
    """
    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_slices_col = []
    for i in range(len(tensor_slices[0][0])):
        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1)
        for j in range(1, device_count):
            tensor_slices_new = np.concatenate((tensor_slices_new,
                                                np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1)
        tensor_slices_col.append(tensor_slices_new)

    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
    for i in range(1, len(tensor_slices_col)):
        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)

    return Tensor(new_tensor)
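To make the rearrangement above easier to follow, here is a hedged, pure-NumPy sketch of the same logic on a toy shape (2 devices, `field_size = [2]`, two features per field, one output column); the values are illustrative only:

```python
# Pure-NumPy sketch of the rearrangement done by _reshape_param_data_with_weight,
# on toy shapes. "gathered" stands in for the AllGather output.
import numpy as np

gathered = np.arange(8, dtype=np.float32).reshape(8, 1)
dev_mat, field_size = [2], [2]

device_count = int(np.prod(dev_mat))
slices = np.split(gathered, device_count, axis=0)  # one (4, 1) slice per device

columns = []
for col in range(gathered.shape[1]):
    # per output column: reshape each device slice to (field_size, -1) and stitch
    # the per-field feature blocks of all devices side by side, then flatten back
    merged = np.concatenate([s[:, col].reshape(field_size[0], -1) for s in slices], axis=1)
    columns.append(merged.reshape(-1, 1))

merged_param = np.concatenate(columns, axis=1)
print(merged_param.ravel())  # [0. 1. 4. 5. 2. 3. 6. 7.]
```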
@@ -359,14 +359,17 @@ def _get_merged_param_data(net, param_name, param_data):
     dev_mat = layout[0]
     tensor_map = layout[1]
+    field_size = layout[3]

     from mindspore.parallel._cell_wrapper import get_allgather_cell
-    from mindspore.parallel._tensor import _reshape_param_data
+    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
     # while any dim is not equal to -1, means param is splited and needs to be merged
     for dim in tensor_map:
         if dim != -1:
             allgather_net = get_allgather_cell()
             param_data = allgather_net(param_data)
+            if field_size[0]:
+                return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
             return _reshape_param_data(param_data, dev_mat, tensor_map)
     return param_data
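A short, hedged illustration of the new branch above: the merge path for a parameter is now chosen from the fourth element of its layout entry (values illustrative):

```python
# Illustration of the new branch in _get_merged_param_data:
# layout = [dev_mat, tensor_map, slice_shape, field_size]
layout = [[2, 4], [0, -1], [16, 32], [39]]
dev_mat, tensor_map, field_size = layout[0], layout[1], layout[3]

if field_size[0]:
    print("merge with _reshape_param_data_with_weight(param_data, dev_mat, field_size)")
else:
    print("merge with _reshape_param_data(param_data, dev_mat, tensor_map)")
```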
@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict
@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, matmul_weight, strategy1=None):
        super().__init__()
        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
        self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
        self.matmul = P.MatMul(transpose_b=False)
        self.index = Tensor(np.ones([64, 64]), dtype=ms.int32)
        self.matmul_weight = Parameter(matmul_weight, "w1")
        self.axis = 0

    def construct(self, x, b):
        out = self.gatherv2(x, self.index, self.axis)
        out = self.reshape(out, (64, -1))
        out = self.matmul(out, self.matmul_weight)
        return out


_w1 = Tensor(np.ones([4096, 32]), dtype=ms.float32)
_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net):
    context.set_context(save_graphs=True)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_reshape_skip_redistribution():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8), (1, 1))
    net = Net(_w1, strategy1)
    compile_net(net)