Merge pull request !30133 from liuluobin/rnd_fix_masterfeature/build-system-rewrite
| @@ -19,6 +19,18 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| int64_t RandomChoiceWithMaskInfo::SEED_NUM = 1; | |||
| Status RandomChoiceWithMaskInfo::GetAttrs() { | |||
| if (attrs_.find(SEED) != attrs_.end()) { | |||
| seed_ = GetValue<int64_t>(attrs_[SEED]); | |||
| } | |||
| if (attrs_.find(SEED2) != attrs_.end()) { | |||
| seed2_ = GetValue<int64_t>(attrs_[SEED2]); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status RandomChoiceWithMaskInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||
| @@ -71,5 +83,39 @@ Status RandomChoiceWithMaskInfo::InferAsLossDivisor() { | |||
| << as_loss_divisor_; | |||
| return SUCCESS; | |||
| } | |||
| void RandomChoiceWithMaskInfo::ReplaceNodeInputOrAttrs() { | |||
| if (seed_ != 0 || seed2_ != 0) { | |||
| return; | |||
| } | |||
| if (cnode_->HasAttr(SEED)) { | |||
| cnode_->EraseAttr(SEED); | |||
| } | |||
| if (cnode_->HasAttr(SEED2)) { | |||
| cnode_->EraseAttr(SEED2); | |||
| } | |||
| cnode_->AddAttr(SEED, MakeValue(SEED_NUM)); | |||
| cnode_->AddAttr(SEED2, MakeValue(SEED_NUM)); | |||
| ++SEED_NUM; | |||
| } | |||
| void RandomChoiceWithMaskInfo::CheckGPUBackend() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (backend != kGPUDevice) { | |||
| MS_LOG(EXCEPTION) << name_ << ": The backend is " << backend << " , only support on GPU backend now."; | |||
| } | |||
| } | |||
| Status RandomChoiceWithMaskInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) { | |||
| CheckGPUBackend(); | |||
| return OperatorInfo::Init(in_strategy, out_strategy); | |||
| } | |||
| Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) { | |||
| CheckGPUBackend(); | |||
| return OperatorInfo::InitForCostModel(strategy, out_strategy); | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -31,18 +31,30 @@ class RandomChoiceWithMaskInfo : public OperatorInfo { | |||
| public: | |||
| RandomChoiceWithMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()), | |||
| seed_(0), | |||
| seed2_(0) {} | |||
| ~RandomChoiceWithMaskInfo() = default; | |||
| Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); } | |||
| void ReplaceNodeInputOrAttrs() override; | |||
| protected: | |||
| Status GetAttrs() override { return SUCCESS; } | |||
| Status GetAttrs() override; | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| Status InferDevMatrixShape() override; | |||
| Status InferTensorMap() override; | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| Status InferAsLossDivisor() override; | |||
| private: | |||
| void CheckGPUBackend(); | |||
| int64_t seed_; | |||
| int64_t seed2_; | |||
| static int64_t SEED_NUM; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -51,6 +51,7 @@ def test_auto_parallel_random_choice_with_mask(): | |||
| Description: auto parallel | |||
| Expectation: compile success | |||
| """ | |||
| context.set_context(device_target="GPU") | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) | |||
| net = Net() | |||
| compile_net(net, _input_x) | |||
| @@ -62,9 +63,24 @@ def test_random_choice_with_mask_wrong_strategy(): | |||
| Description: illegal strategy | |||
| Expectation: raise RuntimeError | |||
| """ | |||
| context.set_context(device_target="GPU") | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy = ((8, 1),) | |||
| net = Net(strategy) | |||
| with pytest.raises(RuntimeError): | |||
| compile_net(net, _input_x) | |||
| context.reset_auto_parallel_context() | |||
| def test_random_choice_with_mask_not_gpu(): | |||
| """ | |||
| Feature: RandomChoiceWithMask | |||
| Description: not compile with gpu backend | |||
| Expectation: raise RuntimeError | |||
| """ | |||
| context.set_context(device_target="Ascend") | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) | |||
| net = Net() | |||
| with pytest.raises(RuntimeError): | |||
| compile_net(net, _input_x) | |||
| context.reset_auto_parallel_context() | |||