Browse Source

!30133 Add backend check for RandomChoiceWithMask

Merge pull request !30133 from liuluobin/rnd_fix_master
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
465f21a3ea
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 76 additions and 2 deletions
  1. +46
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/random_choice_with_mask_info.cc
  2. +14
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/random_choice_with_mask_info.h
  3. +16
    -0
      tests/ut/python/parallel/test_random_choice_with_mask.py

+ 46
- 0
mindspore/ccsrc/frontend/parallel/ops_info/random_choice_with_mask_info.cc View File

@@ -19,6 +19,18 @@

namespace mindspore {
namespace parallel {
int64_t RandomChoiceWithMaskInfo::SEED_NUM = 1;

Status RandomChoiceWithMaskInfo::GetAttrs() {
if (attrs_.find(SEED) != attrs_.end()) {
seed_ = GetValue<int64_t>(attrs_[SEED]);
}
if (attrs_.find(SEED2) != attrs_.end()) {
seed2_ = GetValue<int64_t>(attrs_[SEED2]);
}
return SUCCESS;
}

Status RandomChoiceWithMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
@@ -71,5 +83,39 @@ Status RandomChoiceWithMaskInfo::InferAsLossDivisor() {
<< as_loss_divisor_;
return SUCCESS;
}

void RandomChoiceWithMaskInfo::ReplaceNodeInputOrAttrs() {
if (seed_ != 0 || seed2_ != 0) {
return;
}
if (cnode_->HasAttr(SEED)) {
cnode_->EraseAttr(SEED);
}
if (cnode_->HasAttr(SEED2)) {
cnode_->EraseAttr(SEED2);
}
cnode_->AddAttr(SEED, MakeValue(SEED_NUM));
cnode_->AddAttr(SEED2, MakeValue(SEED_NUM));
++SEED_NUM;
}

void RandomChoiceWithMaskInfo::CheckGPUBackend() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend != kGPUDevice) {
MS_LOG(EXCEPTION) << name_ << ": The backend is " << backend << " , only support on GPU backend now.";
}
}

Status RandomChoiceWithMaskInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
CheckGPUBackend();
return OperatorInfo::Init(in_strategy, out_strategy);
}

Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) {
CheckGPUBackend();
return OperatorInfo::InitForCostModel(strategy, out_strategy);
}
} // namespace parallel
} // namespace mindspore

+ 14
- 2
mindspore/ccsrc/frontend/parallel/ops_info/random_choice_with_mask_info.h View File

@@ -31,18 +31,30 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
public:
RandomChoiceWithMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RandomChoicWithMaskCost>()),
seed_(0),
seed2_(0) {}
~RandomChoiceWithMaskInfo() = default;
Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
Status InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
void ReplaceNodeInputOrAttrs() override;

protected:
Status GetAttrs() override { return SUCCESS; }
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferAsLossDivisor() override;

private:
void CheckGPUBackend();

int64_t seed_;
int64_t seed2_;
static int64_t SEED_NUM;
};
} // namespace parallel
} // namespace mindspore


+ 16
- 0
tests/ut/python/parallel/test_random_choice_with_mask.py View File

@@ -51,6 +51,7 @@ def test_auto_parallel_random_choice_with_mask():
Description: auto parallel
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net()
compile_net(net, _input_x)
@@ -62,9 +63,24 @@ def test_random_choice_with_mask_wrong_strategy():
Description: illegal strategy
Expectation: raise RuntimeError
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((8, 1),)
net = Net(strategy)
with pytest.raises(RuntimeError):
compile_net(net, _input_x)
context.reset_auto_parallel_context()


def test_random_choice_with_mask_not_gpu():
"""
Feature: RandomChoiceWithMask
Description: not compile with gpu backend
Expectation: raise RuntimeError
"""
context.set_context(device_target="Ascend")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net()
with pytest.raises(RuntimeError):
compile_net(net, _input_x)
context.reset_auto_parallel_context()

Loading…
Cancel
Save