| @@ -125,6 +125,7 @@ REGISTER(GetNextInfo); | |||||
| REGISTER(NegInfo); | REGISTER(NegInfo); | ||||
| REGISTER(BatchMatMulInfo); | REGISTER(BatchMatMulInfo); | ||||
| REGISTER(ExpandDimsInfo); | REGISTER(ExpandDimsInfo); | ||||
| REGISTER(SqueezeInfo); | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <utility> | |||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| #include "parallel/auto_parallel/costmodel.h" | #include "parallel/auto_parallel/costmodel.h" | ||||
| @@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() { | |||||
| MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name(); | MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { | |||||
| std::vector<int32_t> axis; | |||||
| auto axis_list = value_tuple->value(); | |||||
| if (inputs_shape_.empty()) { | |||||
| MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; | |||||
| return FAILED; | |||||
| } | |||||
| Shape input_shape = inputs_shape_.at(0); | |||||
| size_t input_size = input_shape.size(); | |||||
| // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1. | |||||
| if (axis_list.empty()) { | |||||
| for (size_t i = 0; i < input_size; ++i) { | |||||
| if (input_shape[i] == 1) { | |||||
| axis.push_back(i); | |||||
| } | |||||
| } | |||||
| axis_ = MakeValue(axis)->cast<ValueTuplePtr>(); | |||||
| return SUCCESS; | |||||
| } | |||||
| // convert negative axis to positive. | |||||
| for (auto& dim : axis_list) { | |||||
| if (!dim->isa<Int32Imm>()) { | |||||
| MS_LOG(ERROR) << name_ << ": The type of axis is not int"; | |||||
| return FAILED; | |||||
| } | |||||
| int32_t dim_value = GetValue<int32_t>(dim); | |||||
| int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value; | |||||
| axis.push_back(positive_value); | |||||
| } | |||||
| axis_ = MakeValue(axis)->cast<ValueTuplePtr>(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SqueezeInfo::GetAttrs() { | |||||
| auto iter = attrs_.find(AXIS); | |||||
| if (iter == attrs_.end()) { | |||||
| MS_LOG(ERROR) << name_ << ": Can't find axis attribute."; | |||||
| return FAILED; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(iter->second); | |||||
| auto value_tuple = iter->second->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||||
| InferAxis(value_tuple); | |||||
| attrs_[AXIS] = axis_; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) { | |||||
| Attr attr = std::make_pair(AXIS, axis_); | |||||
| OperatorAttrs attrs = {attr}; | |||||
| OperatorParams params; | |||||
| OperatorArgs args = std::make_pair(attrs, params); | |||||
| replace_op_ = {std::make_pair(SQUEEZE, args)}; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SqueezeInfo::InferTensorMap() { | |||||
| // for example: if the shape of input is [32, 32, 1], and the axis is (2, ), | |||||
| // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1] | |||||
| std::vector<int32_t> input_tensor_map, output_tensor_map; | |||||
| if (inputs_shape_.empty()) { | |||||
| MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; | |||||
| return FAILED; | |||||
| } | |||||
| size_t size = inputs_shape_[0].size(); | |||||
| std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_); | |||||
| for (size_t i = 0; i < size; ++i) { | |||||
| size_t index = size - i - 1; | |||||
| auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i)); | |||||
| if (iter == axis.end()) { | |||||
| output_tensor_map.push_back(SizeToInt(index)); | |||||
| } | |||||
| input_tensor_map.push_back(SizeToInt(index)); | |||||
| } | |||||
| inputs_tensor_map_.push_back(input_tensor_map); | |||||
| outputs_tensor_map_.push_back(output_tensor_map); | |||||
| MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map) | |||||
| << ", and the tensor map of output is " << ShapeToString(output_tensor_map); | |||||
| return SUCCESS; | |||||
| } | |||||
// Build the TensorInfo (layout + global/slice shape) for the input and output.
// The output strategy is derived from the input strategy by dropping the
// entries whose dimension index appears in the squeeze axes; slice shapes and
// tensor layouts are then computed from those strategies. Assumes GetAttrs and
// InferTensorMap already ran (axis_ and the tensor maps must be populated).
// Returns FAILED on any missing precondition or layout-init failure.
Status SqueezeInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
    return FAILED;
  }
  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
    return FAILED;
  }

  Shape input_shape = inputs_shape_[0];
  Shape output_shape = outputs_shape_[0];

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy;
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
  // The output strategy keeps only the strategy entries of non-squeezed axes,
  // mirroring how the output tensor map was derived in InferTensorMap.
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_strategy.push_back(inputs_strategy[0].at(i));
    }
  }
  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
    return FAILED;
  }
  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape[0];
  Shape output_slice_shape = outputs_slice_shape[0];

  // infer tensor layout
  TensorLayout input_tensor_layout, output_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
    return FAILED;
  }
  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
    return FAILED;
  }

  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);

  return SUCCESS;
}
| Status SqueezeInfo::Init(const StrategyPtr& strategy) { | |||||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||||
| } | |||||
| if (InferReplaceOps(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Infer replace ops failed"; | |||||
| } | |||||
| MS_LOG(INFO) << name_ << " : Init success."; | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -184,6 +184,25 @@ class ExpandDimsInfo : public ActivationOther { | |||||
| Strategys inputs_strategy_; | Strategys inputs_strategy_; | ||||
| Strategys outputs_strategy_; | Strategys outputs_strategy_; | ||||
| }; | }; | ||||
// Parallel operator info for the Squeeze primitive: removes size-1 dimensions
// of the input tensor according to the 'axis' attribute.
class SqueezeInfo : public ActivationOther {
 public:
  SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  ~SqueezeInfo() override = default;

 protected:
  // Normalize the axis tuple (empty tuple -> all size-1 dims; negative axes ->
  // positive) into axis_.
  Status InferAxis(const ValueTuplePtr& value_tuple);
  Status GetAttrs() override;
  // Build the replacement Squeeze op carrying the normalized axis attribute.
  Status InferReplaceOps(const StrategyPtr& strategy);
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  Status Init(const StrategyPtr& strategy) override;

 private:
  ValueTuplePtr axis_;  // normalized (all-positive) squeeze axes
};
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_ | |||||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ | |||||
| @@ -116,4 +116,4 @@ class AssignSubInfo : public ArithmeticBase { | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ARITHMETIC_INFO_H_ | |||||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ | |||||
| @@ -53,4 +53,4 @@ class MaximumInfo : public ArithmeticBase { | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_COMPARISON_FUNCTION_INFO_H_ | |||||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ | |||||
| @@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo { | |||||
| }; | }; | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ONEHOT_INFO_H_ | |||||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ | |||||
| @@ -47,8 +47,8 @@ using mindspore::tensor::Tensor; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; | |||||
| const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; | |||||
| static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; | |||||
| static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; | |||||
| // g_RefMap, for CNode B input i is a RefKey[Parameter C], | // g_RefMap, for CNode B input i is a RefKey[Parameter C], | ||||
| // it will be one item in map with key: C, and value: (B, i) | // it will be one item in map with key: C, and value: (B, i) | ||||
| static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap; | static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap; | ||||
| @@ -1832,7 +1832,6 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector<AnfNodePt | |||||
| if (cnode == loss_cnode) { | if (cnode == loss_cnode) { | ||||
| is_loss_cnode = true; | is_loss_cnode = true; | ||||
| } | } | ||||
| // insert forward ops | // insert forward ops | ||||
| InsertForwardOps(distribute_operator, cnode); | InsertForwardOps(distribute_operator, cnode); | ||||
| @@ -0,0 +1,79 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import numpy as np | |||||
| import mindspore as ms | |||||
| from mindspore import context, Tensor, Parameter | |||||
| from mindspore.nn import Cell, TrainOneStepCell, Momentum | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.api import _executor | |||||
class Net(Cell):
    """Squeeze followed by Mul, used to exercise squeeze parallel strategies.

    strategy1 shards the squeeze input; strategy2 shards both mul operands;
    axis is forwarded to P.Squeeze (empty tuple squeezes all size-1 dims).
    """
    def __init__(self, strategy1=None, strategy2=None, axis=()):
        super().__init__()
        self.squeeze = P.Squeeze(axis=axis).set_strategy(strategy1)
        self.mul = P.Mul().set_strategy(strategy2)

    def construct(self, x, b):
        out = self.squeeze(x)
        out = self.mul(out, b)
        return out
| _x = Tensor(np.ones([64, 1, 32, 1]), dtype=ms.float32) | |||||
| _b = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||||
def compile(net):
    # Compile the net against the module-level fixtures and reset the parallel
    # context so each test starts clean.
    # NOTE(review): this helper shadows the built-in `compile`; consider
    # renaming (kept as-is because the tests below call it by this name).
    _executor.compile(net, _x, _b)
    context.reset_auto_parallel_context()
def test_squeeze_data_parallel():
    """Squeeze with a data-parallel strategy: shard only the batch dimension."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1, 1), )
    strategy2 = ((16, 1), (16, 1))
    net = Net(strategy1, strategy2)
    compile(net)
def test_squeeze_model_parallel():
    """Squeeze with a model-parallel strategy: shard a non-batch dimension."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 16, 1), )
    strategy2 = ((1, 16), (1, 16))
    net = Net(strategy1, strategy2)
    compile(net)
def test_squeeze_specified_axis():
    """Squeeze with an explicit axis tuple (1, 3) and a mixed sharding strategy."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 1, 4, 1), )
    strategy2 = ((8, 2), (8, 2))
    net = Net(strategy1, strategy2, (1, 3))
    compile(net)
def test_squeeze_auto_parallel():
    """Squeeze under auto-parallel mode with no explicit strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    net = Net()
    compile(net)
def test_squeeze_repeat_calc():
    """Squeeze with repeated calculation: strategy uses 8 of 16 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 8, 1), )
    strategy2 = ((2, 8), (2, 8))
    net = Net(strategy1, strategy2)
    compile(net)