Merge pull request !23544 from zhuyuxiao/mastertags/v1.6.0
| @@ -141,7 +141,6 @@ REGISTER(ReLU6Info); | |||
| REGISTER(ReLUV2Info); | |||
| REGISTER(SoftplusInfo); | |||
| REGISTER(SoftsignInfo); | |||
| REGISTER(GatherInfo); | |||
| REGISTER(SparseGatherV2Info); | |||
| REGISTER(SqrtInfo); | |||
| REGISTER(SigmoidInfo); | |||
| @@ -181,7 +180,7 @@ REGISTER(UniformCandidateSamplerInfo); | |||
| REGISTER(UnsortedSegmentSumInfo); | |||
| REGISTER(UnsortedSegmentMinInfo); | |||
| REGISTER(UnsortedSegmentMaxInfo); | |||
| REGISTER(GatherPInfo); | |||
| REGISTER(GatherInfo); | |||
| REGISTER(EmbeddingLookupInfo); | |||
| REGISTER(TileInfo); | |||
| REGISTER(BroadcastToInfo); | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/parallel/ops_info/gather_v2_p_info.h" | |||
| #include "frontend/parallel/ops_info/gather_info.h" | |||
| #include <vector> | |||
| #include <numeric> | |||
| @@ -32,7 +32,7 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status GatherPInfo::GetManualSplitWithoutOffsetAttr() { | |||
| Status GatherInfo::GetManualSplitWithoutOffsetAttr() { | |||
| auto manual_split_without_offset_iter = attrs_.find("manual_split"); | |||
| if (manual_split_without_offset_iter != attrs_.end()) { | |||
| manual_split_ = true; | |||
| @@ -68,7 +68,7 @@ Status GatherPInfo::GetManualSplitWithoutOffsetAttr() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::GetManualSplitAttr() { | |||
| Status GatherInfo::GetManualSplitAttr() { | |||
| auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset"); | |||
| if (manual_split_with_offset_iter != attrs_.end()) { | |||
| manual_split_ = true; | |||
| @@ -118,7 +118,7 @@ Status GatherPInfo::GetManualSplitAttr() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::GetAttrs() { | |||
| Status GatherInfo::GetAttrs() { | |||
| // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. | |||
| if (target_ != CPU) { | |||
| if (input_value_.at(2) == nullptr) { | |||
| @@ -170,7 +170,7 @@ Status GatherPInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::CheckManualSplit(const Strategys &strategy) { | |||
| Status GatherInfo::CheckManualSplit(const Strategys &strategy) { | |||
| if (strategy.size() != 2) { | |||
| MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size(); | |||
| return FAILED; | |||
| @@ -226,7 +226,7 @@ Status GatherPInfo::CheckManualSplit(const Strategys &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) { | |||
| Status GatherInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) { | |||
| auto param_strategy = strategy->GetInputDim().at(0); | |||
| auto index_strategy = strategy->GetInputDim().at(1); | |||
| // param_strategy(axis) != 1, index can't be split | |||
| @@ -255,7 +255,7 @@ Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) { | |||
| // return true: axis is 0, and split the first dimension of parameter and the first dimension of indices | |||
| // otherwise return false | |||
| bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const { | |||
| bool GatherInfo::ShardBatchAndAxis(const Strategys &strategy) const { | |||
| if (axis_ != 0) { | |||
| return false; | |||
| } | |||
| @@ -285,7 +285,7 @@ bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const { | |||
| return true; | |||
| } | |||
| void GatherPInfo::SetAttribute(const StrategyPtr &strategy) { | |||
| void GatherInfo::SetAttribute(const StrategyPtr &strategy) { | |||
| auto param_strategy = strategy->GetInputDim().at(0); | |||
| // axis=0, index_shape(0)%param_strategy(0) must be 0 | |||
| Shape index_shape = inputs_shape_.at(1); | |||
| @@ -312,7 +312,7 @@ void GatherPInfo::SetAttribute(const StrategyPtr &strategy) { | |||
| MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_; | |||
| } | |||
| Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| @@ -373,7 +373,7 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::InferMirrorOps() { | |||
| Status GatherInfo::InferMirrorOps() { | |||
| // There is no mirror operators for manual split | |||
| if (manual_split_) { | |||
| return SUCCESS; | |||
| @@ -403,7 +403,7 @@ Status GatherPInfo::InferMirrorOps() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::InferDevMatrixShape() { | |||
| Status GatherInfo::InferDevMatrixShape() { | |||
| dev_matrix_shape_.clear(); | |||
| out_dev_matrix_shape_.clear(); | |||
| // infer input dev_matrix_shape | |||
| @@ -460,7 +460,7 @@ Status GatherPInfo::InferDevMatrixShape() { | |||
| return SUCCESS; | |||
| } | |||
| void GatherPInfo::InferInputsTensorMap() { | |||
| void GatherInfo::InferInputsTensorMap() { | |||
| // infer input tensor map | |||
| // param_strategy(axis) is not 1 | |||
| size_t param_size = inputs_shape_.at(0).size(); | |||
| @@ -487,7 +487,7 @@ void GatherPInfo::InferInputsTensorMap() { | |||
| inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); | |||
| } | |||
| void GatherPInfo::InferOutputsTensorMap() { | |||
| void GatherInfo::InferOutputsTensorMap() { | |||
| // infer output tensor map | |||
| size_t param_size = inputs_shape_.at(0).size(); | |||
| size_t index_size = inputs_shape_.at(1).size(); | |||
| @@ -534,7 +534,7 @@ void GatherPInfo::InferOutputsTensorMap() { | |||
| (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); | |||
| } | |||
| Status GatherPInfo::InferTensorMap() { | |||
| Status GatherInfo::InferTensorMap() { | |||
| if (manual_split_) { | |||
| Shape param_map = {1, 0}; | |||
| Shape indices_map = {-1, 1}; | |||
| @@ -560,7 +560,7 @@ Status GatherPInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::InferTensorInfo() { | |||
| Status GatherInfo::InferTensorInfo() { | |||
| // infer tensor shape | |||
| Shape input_shape = inputs_shape_.at(0); | |||
| Shape input_index_shape = inputs_shape_.at(1); | |||
| @@ -593,7 +593,7 @@ Status GatherPInfo::InferTensorInfo() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::InferBias() { | |||
| Status GatherInfo::InferBias() { | |||
| CheckGlobalDeviceManager(); | |||
| int64_t rank = g_device_manager->rank_index_in_stage(); | |||
| auto input_shape = inputs_shape_.at(0); | |||
| @@ -656,7 +656,7 @@ Status GatherPInfo::InferBias() { | |||
| return FAILED; | |||
| } | |||
| Status GatherPInfo::InferOffset() { | |||
| Status GatherInfo::InferOffset() { | |||
| CheckGlobalDeviceManager(); | |||
| size_t rank = LongToSize(g_device_manager->rank_index_in_stage()); | |||
| @@ -677,7 +677,7 @@ Status GatherPInfo::InferOffset() { | |||
| return FAILED; | |||
| } | |||
| Status GatherPInfo::InferGroup() { | |||
| Status GatherInfo::InferGroup() { | |||
| size_t dim = LongToSize(axis_); | |||
| int64_t rank = g_device_manager->global_rank(); | |||
| @@ -708,7 +708,7 @@ Status GatherPInfo::InferGroup() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::InferForwardCommunication() { | |||
| Status GatherInfo::InferForwardCommunication() { | |||
| if (manual_split_) { | |||
| return SUCCESS; | |||
| } | |||
| @@ -745,7 +745,7 @@ Status GatherPInfo::InferForwardCommunication() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| Status GatherInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| GenerateGraph gen_g = GenerateGraph(attrs_); | |||
| if (gen_g.Init(cnode) != SUCCESS) { | |||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | |||
| @@ -805,7 +805,7 @@ Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| return SUCCESS; | |||
| } | |||
| ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) { | |||
| ReplaceGraphPtr GatherInfo::replace_graph(const CNodePtr &cnode) { | |||
| if (manual_split_ && target_ != CPU) { | |||
| if (ComputeReplaceGraph(cnode) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; | |||
| @@ -824,7 +824,7 @@ ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) { | |||
| return replace_graph_; | |||
| } | |||
| Status GatherPInfo::ComputeReplaceOp() { | |||
| Status GatherInfo::ComputeReplaceOp() { | |||
| int64_t bias = 0; | |||
| if (manual_split_) { | |||
| if (InferOffset() != SUCCESS) { | |||
| @@ -851,7 +851,7 @@ Status GatherPInfo::ComputeReplaceOp() { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::Init(const StrategyPtr &strategy) { | |||
| Status GatherInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| @@ -864,7 +864,7 @@ Status GatherPInfo::Init(const StrategyPtr &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| if (is_auto_parallel_) { | |||
| MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; | |||
| @@ -882,9 +882,9 @@ Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||
| Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||
| std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| if (manual_split_) { | |||
| MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to search strategy"; | |||
| } | |||
| @@ -900,7 +900,7 @@ std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() { | |||
| std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() { | |||
| if (GetAttrs() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| @@ -29,17 +29,17 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| class GatherPInfo : public OperatorInfo { | |||
| class GatherInfo : public OperatorInfo { | |||
| public: | |||
| GatherPInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2) | |||
| GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), | |||
| axis_(0), | |||
| bias_(0), | |||
| index_offset_(0), | |||
| slice_size_(0), | |||
| replace_op_name_(replace_op_name) {} | |||
| ~GatherPInfo() override = default; | |||
| ~GatherInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| @@ -88,21 +88,21 @@ class GatherPInfo : public OperatorInfo { | |||
| std::vector<int64_t> index_offsets_; | |||
| }; | |||
| class SparseGatherV2Info : public GatherPInfo { | |||
| class SparseGatherV2Info : public GatherInfo { | |||
| public: | |||
| SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2) | |||
| : GatherPInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {} | |||
| : GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {} | |||
| ~SparseGatherV2Info() override = default; | |||
| }; | |||
| class EmbeddingLookupInfo : public GatherPInfo { | |||
| class EmbeddingLookupInfo : public GatherInfo { | |||
| public: | |||
| EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : GatherPInfo(name, inputs_shape, outputs_shape, attrs) {} | |||
| : GatherInfo(name, inputs_shape, outputs_shape, attrs) {} | |||
| ~EmbeddingLookupInfo() override = default; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_ | |||
| @@ -1,318 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/parallel/ops_info/gather_v2_info.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ir/tensor.h" | |||
| #include "ir/value.h" | |||
| #include "frontend/parallel/auto_parallel/costmodel.h" | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status GatherInfo::GetAttrs() { | |||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size(); | |||
| return FAILED; | |||
| } | |||
| // the second input is the index tensor | |||
| // the third input is the axis, is a ValueNode | |||
| if (input_value_.at(2) == nullptr) { | |||
| MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; | |||
| return FAILED; | |||
| } | |||
| if (inputs_shape_.at(0).size() == 0) { | |||
| MS_LOG(ERROR) << name_ << ": input can not be a scalar!"; | |||
| return FAILED; | |||
| } | |||
| int64_t axis = GetValue<int64_t>(input_value_.at(2)); | |||
| if (axis >= SizeToLong(inputs_shape_.at(0).size()) || axis < -SizeToLong(inputs_shape_.at(0).size())) { | |||
| MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", " | |||
| << inputs_shape_.at(0).size() << ")."; | |||
| } | |||
| if (axis < 0) { | |||
| axis += SizeToLong(inputs_shape_[0].size()); | |||
| } | |||
| axis_ = axis; | |||
| index_size_ = inputs_shape_.at(1).size(); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " | |||
| << inputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " | |||
| << outputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| // Only strategy of the first input should be set. | |||
| if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||
| return FAILED; | |||
| } | |||
| axis_strategy_ = strategy->GetInputDim().at(0).at(LongToSize(axis_)); | |||
| if (index_size_ != 1 && axis_strategy_ != 1) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy " | |||
| "corresponding to axis must be 1, but is " | |||
| << axis_strategy_; | |||
| return FAILED; | |||
| } | |||
| if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to " | |||
| "axis. The first dimension of index is " | |||
| << inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GatherInfo::InferDevMatrixShape() { | |||
| Strategys stra = strategy_->GetInputDim(); | |||
| dev_matrix_shape_ = stra.at(0); | |||
| return SUCCESS; | |||
| } | |||
| // If index is a scalar, output dimension is input dimension minus 1; | |||
| // If index is a n dimension tensor, output dimension is input dimension plus (n - 1). | |||
| // Tensor map dimension is equal to the corresponding input and output dimension. | |||
| // If index's dimension is more than 1, we insert -1 for the output tensor map. | |||
| Status GatherInfo::InferTensorMap() { | |||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " | |||
| << inputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " | |||
| << outputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| Shape tensor_map_in; | |||
| Shape tensor_map_out; | |||
| size_t size = inputs_shape_.at(0).size(); | |||
| // such as 4: tensor_map_index [3,2,1,0] | |||
| for (size_t i = 0; i < size; ++i) { | |||
| tensor_map_in.push_back(SizeToLong(size - i - 1)); | |||
| tensor_map_out.push_back(SizeToLong(size - i - 1)); | |||
| } | |||
| if (index_size_ == 0) { | |||
| (void)tensor_map_out.erase(tensor_map_out.begin() + axis_); | |||
| } else if (index_size_ > 1) { | |||
| (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1); | |||
| } | |||
| if (tensor_map_out.size() != outputs_shape_.at(0).size()) { | |||
| MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size() | |||
| << " output size is " << outputs_shape_.at(0).size(); | |||
| return FAILED; | |||
| } | |||
| Shape tensor_map_in_index; | |||
| if (index_size_ >= 1) { | |||
| tensor_map_in_index.push_back(SizeToLong(size) - axis_ - 1); | |||
| } | |||
| for (size_t i = 1; i < index_size_; ++i) { | |||
| tensor_map_in_index.push_back(-1); | |||
| } | |||
| inputs_tensor_map_.emplace_back(std::move(tensor_map_in)); | |||
| inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index)); | |||
| outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherInfo::InferTensorInfo() { | |||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " | |||
| << inputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " | |||
| << outputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is " | |||
| << inputs_tensor_map_.size(); | |||
| return FAILED; | |||
| } | |||
| if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " | |||
| << outputs_tensor_map_.size(); | |||
| return FAILED; | |||
| } | |||
| // infer tensor shape | |||
| Shape input_shape = inputs_shape_.at(0); | |||
| Shape input_index_shape = inputs_shape_.at(1); | |||
| Shape output_shape = outputs_shape_.at(0); | |||
| TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; | |||
| if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || | |||
| (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || | |||
| (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) { | |||
| return FAILED; | |||
| } | |||
| TensorInfo input_tensor_info(input_tensor_layout); | |||
| TensorInfo input_index_info(input_index_layout); | |||
| TensorInfo output_tensor_info(output_tensor_layout); | |||
| inputs_tensor_info_.push_back(input_tensor_info); | |||
| inputs_tensor_info_.push_back(input_index_info); | |||
| outputs_tensor_info_.push_back(output_tensor_info); | |||
| return SUCCESS; | |||
| } | |||
| OperatorVector CreateSubOp(int64_t sub_value) { | |||
| OperatorVector ops; | |||
| OperatorName operator_name = SUB; | |||
| OperatorAttrs operator_attrs; | |||
| std::vector<int64_t> tensor_data = {sub_value}; | |||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, kInt32); | |||
| ValuePtr op_param_value = MakeValue(tensor_ptr); | |||
| Attr op1_param = std::make_pair("", op_param_value); | |||
| OperatorParams operator_param = {std::make_pair(op1_param, 2)}; | |||
| OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); | |||
| Operator op = std::make_pair(operator_name, operator_args); | |||
| ops.push_back(op); | |||
| return ops; | |||
| } | |||
| Status GatherInfo::InferTensorSubOps() { | |||
| sub_ops_.clear(); | |||
| if ((index_size_ == 0) || (axis_strategy_ == 1)) { | |||
| return SUCCESS; | |||
| } | |||
| int64_t mod_n = 1; | |||
| for (size_t i = LongToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) { | |||
| mod_n *= dev_matrix_shape_.at(i); | |||
| } | |||
| if ((axis_ >= SizeToLong(dev_matrix_shape_.size())) || axis_ < 0) { | |||
| MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ")."; | |||
| } | |||
| int64_t mod_p = mod_n * dev_matrix_shape_.at(LongToSize(axis_)); | |||
| int64_t rank = g_device_manager->rank_index_in_stage(); | |||
| int64_t mod_rank = rank % mod_p; | |||
| mod_rank = static_cast<int64_t>(mod_rank / mod_n); | |||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " | |||
| << inputs_shape_.size(); | |||
| return FAILED; | |||
| } | |||
| if ((axis_ >= SizeToLong(inputs_shape_.at(0).size())) || axis_ < 0) { | |||
| MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ")."; | |||
| } | |||
| int64_t sub_value = inputs_shape_[0][LongToSize(axis_)] / dev_matrix_shape_[LongToSize(axis_)] * mod_rank; | |||
| OperatorVector sub_op; | |||
| sub_ops_.emplace_back(std::move(sub_op)); | |||
| sub_op = CreateSubOp(sub_value); | |||
| sub_ops_.emplace_back(std::move(sub_op)); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| Status status = InferTensorSubOps(); | |||
| if (status != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed."; | |||
| return status; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) { | |||
| MS_LOG(EXCEPTION) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" | |||
| << outputs_shape_.size() << "is wrong."; | |||
| } | |||
| Shape input0_split(inputs_shape_[0].size(), 1); | |||
| Shapes splittable_inputs = {input0_split}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != | |||
| SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed."; | |||
| } | |||
| return sp_vector; | |||
| } | |||
| Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||
| std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() { | |||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | |||
| MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " | |||
| << inputs_shape_.size(); | |||
| } | |||
| if (GetAttrs() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "GetAttrs failed!"; | |||
| } | |||
| Dimensions strategy; | |||
| if (index_size_ != 1) { | |||
| strategy.push_back(1); | |||
| } else { | |||
| strategy.push_back(stage_device_size_); | |||
| } | |||
| for (size_t i = 1; i < inputs_shape_[0].size(); i++) { | |||
| strategy.push_back(1); | |||
| } | |||
| Strategys strategy_v = {strategy}; | |||
| return std::make_shared<Strategys>(strategy_v); | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -1,73 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ir/value.h" | |||
| #include "frontend/parallel/auto_parallel/operator_costmodel.h" | |||
| #include "frontend/parallel/ops_info/operator_info.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| constexpr size_t GATHER_V2_INPUTS_SIZE = 2; | |||
| constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1; | |||
| constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; | |||
| // We now supported limited parallel strategies. | |||
| // If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of | |||
| // the input. | |||
| // If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. | |||
| class GatherInfo : public OperatorInfo { | |||
| public: | |||
| GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()), | |||
| axis_(-1), | |||
| index_size_(0), | |||
| axis_strategy_(1) {} | |||
| ~GatherInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| Status InferMirrorOps() override { return SUCCESS; } | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| Status InferTensorInfo() override; | |||
| Status InferDevMatrixShape() override; | |||
| Status InferTensorMap() override; | |||
| Status GetAttrs() override; | |||
| private: | |||
| Status InferTensorSubOps(); | |||
| int64_t axis_; | |||
| size_t index_size_; | |||
| int64_t axis_strategy_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ | |||
| @@ -24,7 +24,6 @@ | |||
| #include "frontend/parallel/ops_info/comparison_function_info.h" | |||
| #include "frontend/parallel/ops_info/dropout_do_mask_info.h" | |||
| #include "frontend/parallel/ops_info/elementary_function_info.h" | |||
| #include "frontend/parallel/ops_info/gather_v2_info.h" | |||
| #include "frontend/parallel/ops_info/get_next_info.h" | |||
| #include "frontend/parallel/ops_info/l2_normalize_info.h" | |||
| #include "frontend/parallel/ops_info/layer_norm_info.h" | |||
| @@ -37,7 +36,7 @@ | |||
| #include "frontend/parallel/ops_info/transpose_info.h" | |||
| #include "frontend/parallel/ops_info/unsorted_segment_op_info.h" | |||
| #include "frontend/parallel/ops_info/virtual_dataset_info.h" | |||
| #include "frontend/parallel/ops_info/gather_v2_p_info.h" | |||
| #include "frontend/parallel/ops_info/gather_info.h" | |||
| #include "frontend/parallel/ops_info/tile_info.h" | |||
| #include "frontend/parallel/ops_info/strided_slice_info.h" | |||
| #include "frontend/parallel/ops_info/slice_info.h" | |||
| @@ -1252,20 +1252,6 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA | |||
| MS_LOG(EXCEPTION) << "Length of name is zero!"; | |||
| } | |||
| std::string distribute_opname = GetDisOpName(name); | |||
| if (name == GATHERV2) { | |||
| distribute_opname = name + "PInfo"; | |||
| auto data_parallel_iter = attrs.find(DATA_PARALLEL); | |||
| if (data_parallel_iter != attrs.end()) { | |||
| MS_EXCEPTION_IF_NULL(data_parallel_iter->second); | |||
| if (!data_parallel_iter->second->isa<BoolImm>()) { | |||
| MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool."; | |||
| } | |||
| bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value(); | |||
| if (data_parallel) { | |||
| distribute_opname = name + "Info"; | |||
| } | |||
| } | |||
| } | |||
| OperatorInfoPtr operator_ = | |||
| (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); | |||
| if (operator_ == nullptr) { | |||
| @@ -2632,9 +2618,9 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_ | |||
| return param_names; | |||
| } | |||
| bool IsGatherPInfo(const std::string &name) { | |||
| std::vector<std::string> gather_p_info_names = {"GatherPInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"}; | |||
| for (std::string info_name : gather_p_info_names) { | |||
| bool IsGatherInfo(const std::string &name) { | |||
| std::vector<std::string> gather_info_names = {"GatherInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"}; | |||
| for (std::string info_name : gather_info_names) { | |||
| if (name.find(info_name) != std::string::npos) { | |||
| return true; | |||
| } | |||
| @@ -2669,10 +2655,10 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap | |||
| for (auto param_name_pair : param_names) { | |||
| tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>(); | |||
| } | |||
| if (IsGatherPInfo(operator_info->name())) { | |||
| auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info); | |||
| auto param_split_shapes = gatherv2_info->param_split_shapes(); | |||
| auto index_offsets = gatherv2_info->index_offsets(); | |||
| if (IsGatherInfo(operator_info->name())) { | |||
| auto gather_info = std::dynamic_pointer_cast<GatherInfo>(operator_info); | |||
| auto param_split_shapes = gather_info->param_split_shapes(); | |||
| auto index_offsets = gather_info->index_offsets(); | |||
| if (param_split_shapes.size() != index_offsets.size()) { | |||
| MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same."; | |||
| } | |||
| @@ -1,240 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import ParameterTuple | |||
| from mindspore.communication.management import init | |||
| from mindspore.nn import Dense, Cell | |||
| from mindspore.nn.loss.loss import LossBase | |||
| from mindspore.nn.optim import Momentum | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication._comm_helper import GlobalComm | |||
# Graph mode is required so the parallel strategies are validated at compile time.
context.set_context(mode=context.GRAPH_MODE)

# Simulated cluster topology shared by every test case below.
device_number = 32
batch_size_per_device = 128
class Dataset():
    """Minimal in-memory dataset stub that yields the same predict tensor.

    Implements just enough of the dataset protocol (iteration, reset, size
    queries) for Model.train() to consume it in these tests.
    """

    def __init__(self, predict, length=3):
        self.predict = predict
        self.index = 0
        self.length = length

    def __iter__(self):
        # The object is its own iterator; call reset() to iterate again.
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            return (self.predict,)
        raise StopIteration

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        # Fixed size reported to the framework, independent of self.length.
        return 128

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        return self
class GatherV2(LossBase):
    """Loss that gathers two disjoint index slices along axis 0 and returns
    their squared difference.

    The Gather primitive is shard()-ed with the given strategy and flagged
    ``data_parallel``, so these tests exercise the data-parallel Gather path.
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super(GatherV2, self).__init__()
        self.pow = P.Pow()
        # Scalar fallbacks keep construction valid for an unexpected index_dim.
        even_indices, odd_indices = 21, 2
        if index_dim == 1:
            all_indices = list(range(index_size))
            even_indices = all_indices[0::2]
            odd_indices = all_indices[1::2]
        if index_dim == 2:
            all_indices = np.arange(index_size * 16)
            even_indices = np.reshape(all_indices[0::2], (int(index_size / 2), 16))
            odd_indices = np.reshape(all_indices[1::2], (int(index_size / 2), 16))
        self.emb1_param = Tensor(even_indices, dtype=mstype.int32)
        self.emb2_param = Tensor(odd_indices, dtype=mstype.int32)
        self.gatherv2 = P.Gather().shard(strategy).add_prim_attr("data_parallel", True)

    def construct(self, nembeddings):
        """Return (gather(even_idx) - gather(odd_idx)) ** 2 along axis 0."""
        first = self.gatherv2(nembeddings, self.emb1_param, 0)
        second = self.gatherv2(nembeddings, self.emb2_param, 0)
        return self.pow((first - second), 2.0)
def fc_with_initialize(input_channels, out_channels):
    """Create a fully-connected (Dense) layer with default initialization."""
    dense_layer = Dense(input_channels, out_channels)
    return dense_layer
class BuildTrainNetwork(nn.Cell):
    """Wrap a backbone network and a criterion into one loss-producing cell."""

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data):
        # Forward through the backbone, then score its output with the criterion.
        return self.criterion(self.network(input_data))
class TrainOneStepCell(Cell):
    """Run one training step: forward pass, backward pass with a constant
    loss sensitivity, then parameter update via the optimizer.

    The wrapped network is flagged ``defer_inline`` so the parallel pass can
    treat it as a separate sub-graph.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, data):
        loss = self.network(data)
        # Sensitivity tensor: same dtype/shape as loss, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(data, sens)
        self.optimizer(grads)
        return loss
def net_trains(criterion, rank):
    """Build a Dense backbone with the given criterion and run a short
    semi-auto-parallel training loop on the simulated 32-device cluster.

    Args:
        criterion: loss cell applied to the backbone's output embeddings.
        rank: global rank configured for this simulated device.
    """
    # init() normally requires distributed env vars; bypass the check here.
    GlobalComm.CHECK_ENVS = False
    init()
    GlobalComm.CHECK_ENVS = True

    learning_rate = 0.1
    momentum = 0.9
    max_epoch = 20
    input_channels = 256
    out_channels = 512

    context.set_context(mode=context.GRAPH_MODE)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                      device_num=device_number, global_rank=rank)

    predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
    dataset = Dataset(predict, 4)

    network = fc_with_initialize(input_channels, out_channels)
    network.set_train()
    train_network = BuildTrainNetwork(network, criterion)
    train_network.set_train()

    optimizer = Momentum(train_network.trainable_params(), learning_rate, momentum)
    train_net = TrainOneStepCell(train_network, optimizer).set_train()

    Model(train_net).train(max_epoch, dataset, dataset_sink_mode=False)
    # Leave global parallel context clean for the next test case.
    context.reset_auto_parallel_context()
def test_auto_batch_parallel():
    """1-D index Gather with no explicit strategy (auto batch parallel)."""
    criterion = GatherV2(1, strategy=None,
                         index_size=batch_size_per_device * device_number)
    net_trains(criterion, 2)
def test_2d_index_auto_batch_parallel():
    """2-D index Gather with no explicit strategy (auto batch parallel)."""
    criterion = GatherV2(2, strategy=None,
                         index_size=batch_size_per_device * device_number)
    net_trains(criterion, 2)
def test_batch_parallel():
    """Explicit pure batch-parallel strategy: split param dim 0 across all devices."""
    criterion = GatherV2(1, strategy=((device_number, 1),),
                         index_size=batch_size_per_device * device_number)
    net_trains(criterion, 2)
def test_strategy1():
    """Mixed split: 16-way on param dim 0, 2-way on param dim 1."""
    criterion = GatherV2(1, strategy=((16, 2),),
                         index_size=batch_size_per_device * device_number)
    net_trains(criterion, 2)
def test_strategy2():
    """Split only the non-gathered param dim 1 across all devices."""
    criterion = GatherV2(1, strategy=((1, device_number),),
                         index_size=batch_size_per_device * device_number)
    net_trains(criterion, 2)
def test_strategy3():
    """Partial split: 8-way on param dim 0, fewer shards than devices."""
    criterion = GatherV2(1, strategy=((8, 1),),
                         index_size=batch_size_per_device * device_number)
    net_trains(criterion, 2)
class GatherV2Axis1(LossBase):
    """Loss that gathers two disjoint index slices along axis 1 and returns
    their squared difference.

    Unlike GatherV2, the primitive is NOT flagged data_parallel, so this
    exercises the general sharded-Gather path on the non-leading axis.
    """

    def __init__(self, index_dim, strategy, index_size=16):
        super(GatherV2Axis1, self).__init__()
        self.pow = P.Pow()
        # Scalar fallbacks keep construction valid for an unexpected index_dim.
        even_indices, odd_indices = 21, 2
        if index_dim == 1:
            all_indices = list(range(index_size))
            even_indices = all_indices[0::2]
            odd_indices = all_indices[1::2]
        if index_dim == 2:
            all_indices = np.arange(index_size * index_size)
            even_indices = np.reshape(all_indices[0::2], (int(index_size / 2), index_size))
            odd_indices = np.reshape(all_indices[1::2], (int(index_size / 2), index_size))
        self.emb1_param = Tensor(even_indices, dtype=mstype.int32)
        self.emb2_param = Tensor(odd_indices, dtype=mstype.int32)
        self.gatherv2 = P.Gather().shard(strategy)

    def construct(self, nembeddings):
        """Return (gather(even_idx) - gather(odd_idx)) ** 2 along axis 1."""
        first = self.gatherv2(nembeddings, self.emb1_param, 1)
        second = self.gatherv2(nembeddings, self.emb2_param, 1)
        return self.pow((first - second), 2.0)
def test_axis1_auto_batch_parallel():
    """Axis-1 Gather with no explicit strategy (auto batch parallel)."""
    criterion = GatherV2Axis1(1, strategy=None, index_size=512)
    net_trains(criterion, 2)
def test_axis1_batch_parallel():
    """Axis-1 Gather: split param dim 0 across all devices, index unsplit."""
    criterion = GatherV2Axis1(1, strategy=((device_number, 1), (1,)),
                              index_size=512)
    net_trains(criterion, 2)
def test_axis1_strategy1():
    """Axis-1 Gather with a mixed (16, 2) param split, run as rank 17."""
    criterion = GatherV2Axis1(1, strategy=((16, 2), (1,)), index_size=512)
    net_trains(criterion, 17)