@@ -488,10 +488,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
       }
       TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
       SetOutputLayout(next_in_tensor_info.tensor_layout());
-      if (Init(nullptr) == FAILED) {
-        MS_LOG(DEBUG) << "Failure:operator reshape init failed";
-        continue;
-      }
+      InferTensorInfoByLayout();
       SetCostForReshape(reshape_stra);
     }
   }
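Reviewer note: the per-strategy loop no longer re-runs the full operator Init (which could fail and silently skip a strategy); it derives the tensor info directly from the input/output layouts that were just set. A hedged sketch of what the new helper presumably does, with assumed member names (the real body lives in reshape_info.cc):

    // Sketch only: assumes input_layout_/output_layout_ members and that
    // InferTensorInfoByLayout() repopulates the cached TensorInfo vectors.
    void ReshapeInfo::InferTensorInfoByLayout() {
      TensorInfo input_tensor_info(input_layout_);
      TensorInfo output_tensor_info(output_layout_);
      inputs_tensor_info_.clear();
      outputs_tensor_info_.clear();
      inputs_tensor_info_.push_back(input_tensor_info);
      outputs_tensor_info_.push_back(output_tensor_info);
    }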
@@ -63,6 +63,14 @@ std::shared_ptr<ReshapeLayoutTransfer> RedistributionLayoutTransfer::UnifyDevice
   if (unified_device_arrangement_ptr == nullptr) {
     return nullptr;
   }
+  Shape in_expand_shape;
+  Status status = ExpandShape(unified_device_arrangement_ptr->from_in().tensor_shape().array(),
+                              unified_device_arrangement_ptr->to_in().tensor_shape().array(), &in_expand_shape);
+  if (status != Status::SUCCESS) {
+    MS_LOG(INFO) << "The from and to shapes cannot be unified by expanding";
+    unified_device_arrangement_ptr->SetExpandAble(false);
+    return unified_device_arrangement_ptr;
+  }
   return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape();
 }
 } // namespace parallel
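Reviewer note: the new guard probes whether the `from` tensor shape can be rewritten as a pure axis-split of the `to` shape before unification is attempted. A worked failure case, matching test_reshape_unexpand_3 below: reshaping [4, 3] to [3, 4] cannot be expressed by splitting axes, since 4 and 3 admit no common split (neither divides the other). A usage sketch; the free-function form of ExpandShape here is an assumption read off this call site:

    Shape in_expand_shape;
    // Expected to return FAILED for [4, 3] -> [3, 4]: no common axis
    // boundaries exist, so no expanded shape covers both factorizations.
    if (ExpandShape({4, 3}, {3, 4}, &in_expand_shape) != Status::SUCCESS) {
      // the transfer is flagged via SetExpandAble(false) and the caller
      // falls back to InferTensorRedistributionOperatorListUnExpand()
    }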
@@ -35,12 +35,15 @@ class ReshapeLayoutTransfer : public LayoutTransfer {
   std::shared_ptr<ReshapeLayoutTransfer> ExpandFromTensorShapeAndExpandToDeviceArrangement(
     const Arrangement &expand_shape) const;
   std::shared_ptr<ReshapeLayoutTransfer> ExchangeFromAndTo() const;
+  bool ExpandAble() const { return is_expand_able_; }
+  bool FromTensorShapeCanBeExpandByTo() const;
+  bool ToTensorShapeCanBeExpandByFrom() const;
+  void SetExpandAble(const bool is_expand_able) { is_expand_able_ = is_expand_able; }

  private:
   Status CheckValidTransfer() override;
   std::shared_ptr<Arrangement> ComputeExpandedFromTensorShapeByTo() const;
-  bool FromTensorShapeCanBeExpandByTo() const;
-  bool ToTensorShapeCanBeExpandByFrom() const;
+  bool is_expand_able_ = true;
 };
 } // namespace parallel
 } // namespace mindspore
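Reviewer note: `ExpandAble()`/`SetExpandAble()` expose the new flag, and the two CanBeExpand predicates move from private to public so callers can probe expandability before committing to a path. A minimal consumption sketch, mirroring the call site added in tensor_redistribution.cc below:

    std::shared_ptr<ReshapeLayoutTransfer> ptr = layout_transfer.UnifyDeviceArrangement();
    if (ptr != nullptr && !ptr->ExpandAble()) {
      // shapes could not be unified by axis-splitting; take the
      // replicate-reshape-redistribute fallback instead
    }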
@@ -97,11 +97,11 @@ Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *
   int64_t value = 1;
   for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) {
     if (*iter == 0) {
-      MS_LOG(ERROR) << "element of shape_accum should not be zero";
+      MS_LOG(WARNING) << "element of shape_accum should not be zero";
       return Status::FAILED;
     }
     if ((*iter) % value != 0) {
-      MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order";
+      MS_LOG(WARNING) << "shape_accum is not an accumulate product in ascending order";
       return Status::FAILED;
     }
     (void)shape->insert(shape->begin(), static_cast<int64_t>((*iter) / value));
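Reviewer note: only the log level changes here, presumably because callers now treat FAILED as a recoverable probe rather than a hard error. For reference, a worked trace of the loop, assuming the elided context advances `value` to `*iter` after each insert: shape_accum_reverse = {24, 12, 4} encodes shape {2, 3, 4} (24 = 2*3*4, 12 = 3*4, 4 = 4):

    // value = 1 :  4 / 1   == 4 -> shape = {4}
    // value = 4 :  12 / 4  == 3 -> shape = {3, 4}
    // value = 12:  24 / 12 == 2 -> shape = {2, 3, 4}
    // A zero element, or a step where value does not divide *iter, means the
    // input was never an ascending accumulate product -> Status::FAILED.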
@@ -390,6 +390,15 @@ TensorLayout TensorLayout::SqueezeShape() const {
   return out;
 }

+TensorLayout TensorLayout::TransferRepeatLayout() const {
+  Shape dev_mat(device_arrangement_.array());
+  Shape tensor_map(tensor_map_.GetDimSize(), -1);
+  Shape tensor_shape(tensor_shape_.array());
+  TensorLayout repeat;
+  repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
+  return repeat;
+}
+
 // Generate a totally shard tensor slice shape for parallel optimizer
 Status TensorLayout::GenerateOptShardSliceShape() {
   MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString();
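Reviewer note: a tensor_map of all -1 binds no tensor axis to any device axis, so the resulting layout is fully replicated. For example, with device_arrangement {2, 4} and tensor_shape {128, 96}, TransferRepeatLayout() yields tensor_map {-1, -1} and every device holds the whole {128, 96} tensor. One nit: the Status from InitFromVector is dropped; a hedged variant would check it:

    TensorLayout repeat;
    if (repeat.InitFromVector(dev_mat, tensor_map, tensor_shape) != Status::SUCCESS) {
      MS_LOG(ERROR) << "init of the repeated (fully replicated) layout failed";
    }
    return repeat;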
@@ -88,6 +88,8 @@ class TensorLayout {
   TensorLayout SqueezeShape() const;
+  TensorLayout TransferRepeatLayout() const;
   Status GenerateOptShardSliceShape();
+  Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; }
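Reviewer note: the new accessor only reads a member, so it could be const-qualified; a sketch of that variant:

    Shape opt_shard_slice_shape() const { return opt_shard_slice_shape_; }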
@@ -39,6 +39,42 @@ Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &
   return Status::SUCCESS;
 }

+RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) {
+  TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
+  TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
+  MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
+  MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
+  MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
+  MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
+  MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
+  MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
+  OperatorVector operator_vector;
+  OutPutInfoVector output_info_vector;
+  if (InferRedistribution(from_origin_, from_repeat, &operator_vector, &output_info_vector, is_cost_model) ==
+      Status::FAILED) {
+    return nullptr;
+  }
+  if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) {
+    reshape_flag_ = true;
+    ConstructOperator constructor;
+    constructor.UpdateTensorShape(from_repeat.slice_shape().array());
+    Arrangement shape = to_repeat.slice_shape();
+    MS_LOG(DEBUG) << "reshape " << shape.ToString();
+    if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
+      return nullptr;
+    } else {
+      (void)operator_vector.push_back(constructor.GetOperator());
+      (void)output_info_vector.push_back(std::make_pair(false, 0));
+    }
+  }
+  if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) ==
+      Status::FAILED) {
+    return nullptr;
+  }
+  return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
+    std::make_pair(operator_vector, output_info_vector));
+}
+
 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
   // Step 1: Match device arrangement between from_ and to_
   RedistributionLayoutTransfer layout_transfer;
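Reviewer note: the debug label for the second log line said "to_layout" while printing to_repeat; corrected above. The un-expand fallback is a replicate-reshape-redistribute pipeline. Worked through test_reshape_unexpand_3 below (from = [4, 3] sharded (4, 1), to = [3, 4] sharded (1, 4), 4 devices), the expected behavior is roughly as follows; the concrete operators are chosen by InferRedistribution, so this is a sketch of intent, not a guarantee:

    // 1. from_origin_ -> from_repeat: gather the four [1, 3] row slices so
    //    every device holds the full [4, 3] tensor.
    // 2. from_repeat slice [4, 3] != to_repeat slice [3, 4]: append a Reshape
    //    op on the replicated tensor (reshape_flag_ = true feeds ComputeCost).
    // 3. to_repeat -> to_origin_: each device keeps its [3, 1] column slice
    //    of the reshaped [3, 4] tensor.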
@@ -51,6 +87,10 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
     MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
     return nullptr;
   }
+  if (!ptr->ExpandAble()) {
+    expand_able_ = false;
+    return InferTensorRedistributionOperatorListUnExpand(is_cost_model);
+  }
   TensorLayout from_layout = ptr->from_in();
   TensorLayout to_layout = ptr->to_in();
   MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
| @@ -61,27 +101,17 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL | |||
| MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); | |||
| // Step 2: Infer redistribution and insert operators | |||
| RedistributionOperatorInfer operator_infer(construct_op_flag_); | |||
| if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Init operatorInfer failed!"; | |||
| return nullptr; | |||
| } | |||
| OperatorVector operator_vector; | |||
| OutPutInfoVector output_info_vector; | |||
| if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) { | |||
| MS_LOG(ERROR) << "Infer redistribution failed!"; | |||
| if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) != | |||
| Status::SUCCESS) { | |||
| return nullptr; | |||
| } else { | |||
| operator_vector = operator_infer.operator_vector(); | |||
| output_info_vector = operator_infer.output_info_vector(); | |||
| operator_list_ = operator_infer.operator_list(); | |||
| } | |||
| // Step 3: Infer reshape and insert operators | |||
| if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) { | |||
| MS_LOG(ERROR) << "Construct Reshape operator failed!"; | |||
| return nullptr; | |||
| } | |||
| return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>( | |||
| std::make_pair(operator_vector, output_info_vector)); | |||
| } | |||
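Reviewer note: the operator-infer boilerplate moves into the shared InferRedistribution helper so this path and the un-expand path cannot drift apart. One subtlety worth calling out: the old code assigned into operator_vector, whereas the helper appends. Append semantics are exactly what lets the un-expand path call it twice into the same vectors:

    // Sketch of the un-expand call pattern; both passes accumulate into the
    // same vectors, which only works because InferRedistribution appends
    // rather than overwriting:
    InferRedistribution(from_origin_, from_repeat, &operator_vector, &output_info_vector, is_cost_model);
    // (optional Reshape op appended here when the slice shapes differ)
    InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model);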
@@ -136,6 +166,31 @@ Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const
   return Status::SUCCESS;
 }

+Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
+                                                 OperatorVector *const operator_vector,
+                                                 OutPutInfoVector *const output_info_vector, bool is_cost_model) {
+  RedistributionOperatorInfer operator_infer(construct_op_flag_);
+  if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
+    MS_LOG(ERROR) << "Init operatorInfer failed";
+    return Status::FAILED;
+  }
+  if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
+    MS_LOG(ERROR) << "Infer redistribution failed";
+    return Status::FAILED;
+  } else {
+    for (auto op : operator_infer.operator_vector()) {
+      operator_vector->insert(operator_vector->end(), op);
+    }
+    for (auto info : operator_infer.output_info_vector()) {
+      output_info_vector->insert(output_info_vector->end(), info);
+    }
+    for (auto opc : operator_infer.operator_list()) {
+      operator_list_.insert(operator_list_.end(), opc);
+    }
+  }
+  return Status::SUCCESS;
+}
+
 Status TensorRedistribution::ComputeCost() {
   RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
   if (redistribution_oplist_ptr == nullptr) {
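Reviewer note: the element-wise insert loops above are correct but verbose; a behavior-preserving, more idiomatic variant (assuming the accessors return containers or ranges that can be bound to const references) is a bulk range insert:

    const auto &ops = operator_infer.operator_vector();
    operator_vector->insert(operator_vector->end(), ops.begin(), ops.end());
    const auto &infos = operator_infer.output_info_vector();
    output_info_vector->insert(output_info_vector->end(), infos.begin(), infos.end());
    const auto &op_list = operator_infer.operator_list();
    operator_list_.insert(operator_list_.end(), op_list.begin(), op_list.end());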
@@ -162,8 +217,13 @@ Status TensorRedistribution::ComputeCost() {
     }
   }
   if (reshape_flag()) {
-    Shape prev_slice_shape = from_.slice_shape().array();
-    double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
+    Shape prev_shape;
+    if (expand_able_) {
+      prev_shape = from_.slice_shape().array();
+    } else {
+      prev_shape = from_.tensor_shape().array();
+    }
+    double prev_prod = std::accumulate(prev_shape.begin(), prev_shape.end(), 1, std::multiplies<int>());
     computation_cost_ += 2.0 * prev_prod;
     memory_cost_ += 2.0 * prev_prod;
   }
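Reviewer note: the branch is needed because on the un-expand path the reshape runs on the replicated full tensor, not on the local slice, so from_.slice_shape() would under-count. Worked arithmetic: prev_shape = {128, 96} gives prev_prod = 12288, adding 24576.0 to both computation_cost_ and memory_cost_. Separately, std::accumulate here multiplies int64_t dims through std::multiplies<int> with an int init, which can narrow or overflow for large shapes; a hedged fix:

    // accumulate in double from the start; each int64_t dim converts exactly
    // for realistic shape sizes
    double prev_prod = std::accumulate(prev_shape.begin(), prev_shape.end(), 1.0,
                                       std::multiplies<double>());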
@@ -61,8 +61,12 @@ class TensorRedistribution {
  private:
   Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
                       OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector);
+  Status InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
+                             OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector,
+                             bool is_cost_model);
   Status ComputeConcatCost(double input_size, Shape attrs);
   Status ComputePermuteCost(double input_size, Shape attrs);
+  RedistributionOpListPtr InferTensorRedistributionOperatorListUnExpand(bool is_cost_model = false);
   TensorLayout from_origin_;
   TensorLayout to_origin_;
   TensorLayout from_;
@@ -84,6 +88,7 @@ class TensorRedistribution {
   double memory_cost_;
   bool construct_op_flag_;
   bool keep_reshape_;
+  bool expand_able_ = true;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -0,0 +1,206 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.common.parameter import Parameter
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+grad_all = C.GradOperation(get_all=True)
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x):
+        predict = self.network(x)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x):
+        return grad_all(self.network)(x)
+
+
+def test_reshape_unexpand():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
+            self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            weight = self.reshape(self.mul_weight, (1, 128, 96))
+            out = self.mul(x, weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
+
+def test_reshape_unexpand_1():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
+            self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            weight = self.reshape(self.mul_weight, (1, 128, 96))
+            out = self.mul(x, weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
+
+def test_reshape_unexpand_2():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
+            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
+
+        def construct(self, data):
+            x = self.reshape(self.mul_weight, (1, 128, 96))
+            out = self.mul(x, self.mul_weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
+
+def test_reshape_unexpand_3():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.relu1 = P.ReLU().shard(((4, 1),))
+            self.relu2 = P.ReLU().shard(((1, 4),))
+
+        def construct(self, data):
+            x = self.relu1(data)
+            x = self.reshape(x, (3, 4))
+            x = self.relu2(x)
+            return x
+
+    size = 4
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 3]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
+
+def test_reshape_unexpand_4():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.relu1 = P.ReLU().shard(((4, 1),))
+            self.relu2 = P.ReLU().shard(((1, 2, 2),))
+
+        def construct(self, data):
+            x = self.relu1(data)
+            x = self.reshape(x, (3, 2, 2))
+            x = self.relu2(x)
+            return x
+
+    size = 4
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 3]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
+
+def test_reshape_unexpand_5():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.relu1 = P.ReLU().shard(((2, 2, 1),))
+            self.relu2 = P.ReLU().shard(((1, 4),))
+
+        def construct(self, data):
+            x = self.relu1(data)
+            x = self.reshape(x, (3, 4))
+            x = self.relu2(x)
+            return x
+
+    size = 4
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([2, 2, 3]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
+
+def test_reshape_unexpand_6():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.relu1 = P.ReLU().shard(((2, 1),))
+            self.relu2 = P.ReLU().shard(((1, 1, 4),))
+
+        def construct(self, data):
+            x = self.relu1(data)
+            x = self.reshape(x, (1, 3, 4))
+            x = self.relu2(x)
+            return x
+
+    size = 4
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 3]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)