From: @yangzhenzhang  Reviewed-by: @kisnwang, @stsuteng  Signed-off-by: @stsuteng  (pull/15227/MERGE)
@@ -608,6 +608,7 @@ using GreaterCost = SubCost;
 using GreaterEqualCost = SubCost;
 using LessCost = SubCost;
 using LessEqualCost = SubCost;
+using GatherNdCost = SubCost;

 class MulCost : public SubCost {
  public:
@@ -191,6 +191,7 @@ REGISTER(StackInfo);
 REGISTER(ConcatInfo);
 REGISTER(SplitInfo);
 REGISTER(UniqueInfo);
+REGISTER(GatherNdInfo);
 }  // namespace parallel
 }  // namespace mindspore
@@ -0,0 +1,214 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/gathernd_info.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace parallel {
// the input cannot be split, and the last dimension of indices cannot be split
Status GatherNdInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.size() != 2) {
    MS_LOG(ERROR) << name_ << ": The size of strategies must be 2";
    return FAILED;
  }

  // the product of the input strategy must be 1, i.e. the input is fully replicated
  int64_t input_split_size = std::accumulate(stra[0].begin(), stra[0].end(), 1, std::multiplies<int64_t>());
  if (input_split_size != 1) {
    MS_LOG(ERROR) << name_ << ": The input can not be split";
    return FAILED;
  }

  if (stra[1].empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy of indices can not be empty";
    return FAILED;
  }

  // each row along the last axis of indices is one complete coordinate tuple, so it must stay on one device
  if (stra[1].back() != 1) {
    MS_LOG(ERROR) << name_ << ": The last dimension of indices can not be split";
    return FAILED;
  }
  return SUCCESS;
}
// the device matrix is the indices' strategy
Status GatherNdInfo::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.size() != 2) {
    MS_LOG(ERROR) << name_ << ": The size of strategies must be 2";
    return FAILED;
  }

  dev_matrix_shape_ = stra[1];
  return SUCCESS;
}
// input shape: [x, y, z], indices shape: [a, b, c, 2], output shape: [a, b, c, z]
// strategy: ((1, 1, 1), (m, n, o, 1))
// dev-matrix: [m, n, o, 1]
// input map: [-1, -1, -1], indices map: [3, 2, 1, 0], output map: [3, 2, 1, -1]
Status GatherNdInfo::InferTensorMap() {
  if (inputs_shape_.size() != 2) {
    MS_LOG(ERROR) << name_ << ": The size of input shapes must be 2";
    return FAILED;
  }

  if (outputs_shape_.empty() || outputs_shape_[0].size() < (inputs_shape_[1].size() - 1)) {
    MS_LOG(ERROR) << name_ << ": Invalid shapes";
    return FAILED;
  }

  TensorMap input_tensor_map(inputs_shape_[0].size(), MAP_NONE);  // the input can not be split

  // use inputs_shape_[1].size() rather than dev_matrix_shape_.size(): if the strategy does not fully use all
  // devices, a repeated-calculation dimension may have been appended to dev_matrix_shape_.
  TensorMap indices_tensor_map;
  int64_t size = SizeToLong(inputs_shape_[1].size());
  for (int64_t i = 0; i < size; ++i) {
    indices_tensor_map.push_back(size - i - 1);
  }

  // the first (rank(indices) - 1) output dims inherit the indices' mapping; the remaining dims are not split
  TensorMap output_tensor_map(outputs_shape_[0].size(), MAP_NONE);
  for (size_t i = 0; i < (inputs_shape_[1].size() - 1); ++i) {
    output_tensor_map[i] = indices_tensor_map[i];
  }

  inputs_tensor_map_.push_back(input_tensor_map);
  inputs_tensor_map_.push_back(indices_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  return SUCCESS;
}

Status GatherNdInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }

  TensorLayout input_layout, output_layout;
  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    // infer tensor layout
    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
      return FAILED;
    }
    TensorInfo input_tensor_info(input_layout);
    inputs_tensor_info_.push_back(input_tensor_info);
  }

  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
    return FAILED;
  }
  TensorInfo output_tensor_info(output_layout);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

// for batch parallel: do not split the input, split only the indices along its first dimension
void GatherNdInfo::ReComputeBatchSplitFlagList() {
  split_flag_list_[0] = false;  // input
  split_flag_list_[1] = true;   // indices
}

Status GatherNdInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

Status GatherNdInfo::GenerateStrategies(int64_t stage_id) {
  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
    return FAILED;
  }
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  // generate candidate strategies for indices only; every dimension is splittable except the last one
  Shape input_split(inputs_shape_[1].size(), 1);
  input_split.back() = 0;
  Shapes splittable_input = {input_split};
  Shapes tmp_inputs_shape = {inputs_shape_[1]};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Generate strategies failed";
    return FAILED;
  }

  // the input's strategy is fixed to all-ones (fully replicated); prepend it to each candidate strategy
  for (auto &sp : sp_vector) {
    if ((sp == nullptr) || sp->GetInputDim().empty()) {
      MS_LOG(ERROR) << name_ << ": The strategy is null or empty";
      return FAILED;
    }
    Strategys tmp_strategy;
    Dimensions indices_strategy = sp->GetInputDim()[0];
    Dimensions input_strategy(inputs_shape_[0].size(), 1);
    tmp_strategy.push_back(input_strategy);
    tmp_strategy.push_back(indices_strategy);
    sp->ResetInputs(tmp_strategy);
  }

  size_t success = 0;
  for (auto &sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}

Status GatherNdInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status GatherNdInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore
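For context: GatherNd reads the input at the coordinate tuples formed along the last axis of indices, which is why the implementation above forbids splitting the input and the coordinate axis. A minimal single-device sketch of these semantics (shapes and values chosen for illustration, not taken from this patch):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

gathernd = P.GatherNd()
params = Tensor(np.arange(12).reshape(3, 4), ms.float32)  # "input", shape [3, 4]
indices = Tensor([[0, 1], [2, 3]], ms.int32)              # each row is one (row, col) coordinate
out = gathernd(params, indices)                           # shape [2], values [1.0, 11.0]
```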
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHERND_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHERND_INFO_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class GatherNdInfo : public OperatorInfo {
 public:
  GatherNdInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherNdCost>()) {}
  ~GatherNdInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int64_t) override;
  Status SetCostUnderStrategy(const StrategyPtr &) override;
  void ReComputeBatchSplitFlagList() override;

 protected:
  Status GetAttrs() override { return SUCCESS; }
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
};

using GatherNdInfoPtr = std::shared_ptr<GatherNdInfo>;
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHERND_INFO_H_
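The constraint enforced in CheckStrategy follows directly from these semantics: a row along the last axis of indices is one complete coordinate, so splitting that axis would scatter a single lookup across devices, while splitting any leading axis of indices merely partitions the output rows. A NumPy sketch of that equivalence (a hand-written illustration, not part of the patch):

```python
import numpy as np

params = np.arange(12).reshape(3, 4)
indices = np.array([[0, 1], [2, 3], [1, 2], [0, 0]])  # shape (4, 2)

full = params[tuple(indices.T)]       # GatherNd: one value per coordinate row
part0 = params[tuple(indices[:2].T)]  # "device 0" gathers with the first half of indices
part1 = params[tuple(indices[2:].T)]  # "device 1" gathers with the second half
assert np.array_equal(full, np.concatenate([part0, part1]))
```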
@@ -50,5 +50,6 @@
 #include "frontend/parallel/ops_info/unique_info.h"
 #include "frontend/parallel/ops_info/uniform_candidate_sampler_info.h"
 #include "frontend/parallel/ops_info/reluv2_info.h"
+#include "frontend/parallel/ops_info/gathernd_info.h"

 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
@@ -325,6 +325,7 @@ constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
 constexpr char DROPOUT[] = "Dropout";
 constexpr char KStridedSlice[] = "StridedSlice";
 constexpr char UNIQUE[] = "Unique";
+constexpr char GATHERND[] = "GatherNd";

 // Parallel don't care
 constexpr char STRING_EQUAL[] = "string_equal";
@@ -163,7 +163,8 @@ bool IsSplittableOperator(const std::string &op_name) {
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE,
-    UNSORTED_SEGMENT_MAX};
+    UNSORTED_SEGMENT_MAX, GATHERND};
   // clang-format on
   auto iter = splittable_op.find(op_name);
@@ -492,10 +493,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   std::map<size_t, size_t> loop_to_ops;
   // extract strategy from checkpoint for multi-train
   StrategyMap stra_map;
-  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
-    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-    }
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
+      StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
   }
   std::vector<std::string> last_forward_node_ids;
   if (!root->has_flag(TRAINING)) {
@@ -505,8 +505,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();
-    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
-    if (bool_result) {
+    if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
       continue;
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -551,9 +550,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
                          last_forward_node_ids.end();
     auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
-    if (operator_info == nullptr) {
-      return FAILED;
-    }
+    MS_EXCEPTION_IF_NULL(operator_info);
     // Needed by rec_parser
     operator_info->set_type(prim->name());
     operator_info->set_last_node_flag(is_last_nodes);
@@ -627,8 +625,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
   MS_LOG(INFO) << "Constructing edges for cost graph begins.";
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
-    if (bool_result_cnode) {
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
     auto &inputs = cnode->inputs();
@@ -638,7 +635,6 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     size_t edge_count = 0;
-
     auto node_op_info = cnode->user_data<OperatorInfo>();
     for (size_t i = 1; i < inputs.size(); ++i) {
@@ -0,0 +1,110 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class Net(Cell):
    def __init__(self, w1, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(w1, "w1")
        self.indices = Tensor(np.ones([16, 2]), dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)

    def construct(self, x, b):
        out = self.mul(x, self.w1)
        out = self.gathernd(out, self.indices)
        return out


_x = Tensor(np.ones([16, 64]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64]), dtype=ms.float32)
_w1 = Tensor(np.ones([16, 64]), dtype=ms.float32)


def compile_net(net):
    context.set_context(save_graphs=True)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()


def test_gathernd_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8, 1))
    strategy2 = ((1, 1), (8, 1))
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_gathernd_model_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 4), (2, 4))
    strategy2 = ((1, 1), (4, 1))
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_gathernd_auto_parallel():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w1)
    compile_net(net)


def test_gathernd_strategy_error():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8, 1))
    strategy2 = ((1, 1), (2, 4))  # splits the last dim of indices, which CheckStrategy rejects
    net = Net(_w1, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)
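In terms of the new CheckStrategy, the passing tests leave the input whole and split only the leading dimension of indices, while the failing one splits the coordinate dimension. Condensed, assuming the same 8-device setup as the tests above:

```python
from mindspore.ops import operations as P

# Accepted: input [16, 64] fully replicated; indices [16, 2] split 8-way on dim 0.
ok = P.GatherNd().shard(((1, 1), (8, 1)))

# Rejected: the last dim of indices holds a whole coordinate tuple and must not be
# split, so compiling a net that uses this strategy raises RuntimeError.
bad = P.GatherNd().shard(((1, 1), (2, 4)))
```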