Browse Source

!8063 fix ReLUV2 mask error

Merge pull request !8063 from yihuaijie/dev
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
33aa2ae16b
2 changed files with 31 additions and 7 deletions
  1. +29
    -5
      mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc
  2. +2
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h

+ 29
- 5
mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc View File

@@ -33,12 +33,27 @@ namespace mindspore {
namespace parallel {
Status ReLUV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}

Strategys stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0);
if (input_strategy[1] != 1) {
MS_LOG(ERROR) << name_ << "The second dimension is not splitable.";
return FAILED;
}

return SUCCESS;
}

Status ReLUV2Info::GetAttrs() { return SUCCESS; }

Status ReLUV2Info::GenerateStrategies(int32_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
// the second dimension is not splitable
input0_split[1] = 0;
Shapes splittable_inputs = {input0_split};

std::vector<StrategyPtr> sp_vector;
@@ -97,6 +112,7 @@ Status ReLUV2Info::InferForwardCommunication() {

Status ReLUV2Info::InferTensorMap() {
Shape tensor_map_index;
Shape tensor_map_mask;
size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) {
@@ -104,9 +120,12 @@ Status ReLUV2Info::InferTensorMap() {
}

inputs_tensor_map_.push_back(tensor_map_index);
// output and mask
outputs_tensor_map_.push_back(tensor_map_index);
// output
outputs_tensor_map_.push_back(tensor_map_index);
tensor_map_mask = tensor_map_index;
// mask format NC1HWC0
tensor_map_mask.push_back(MAP_NONE);
outputs_tensor_map_.push_back(tensor_map_mask);
return SUCCESS;
}

@@ -116,7 +135,7 @@ Status ReLUV2Info::InferTensorInfo() {
return FAILED;
}

TensorLayout input_layout, output_layout;
TensorLayout input_layout, output_layout, mask_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
@@ -129,10 +148,15 @@ Status ReLUV2Info::InferTensorInfo() {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
if (mask_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
TensorInfo mask_tensor_info(mask_layout);
// output and mask
outputs_tensor_info_.push_back(output_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
outputs_tensor_info_.push_back(mask_tensor_info);
return SUCCESS;
}



+ 2
- 2
mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h View File

@@ -30,8 +30,8 @@
namespace mindspore {
namespace parallel {
/*
* The input, output and mask have the same tensormap.
* And all dimensions of input are splitable.
* The second dimension is not splitable, as mask is caculated along it.
* The input and output have the same tensormap (3, 2, 1, 0), mask's tensormap is (3, 2, 1, 0, -1)
*/
class ReLUV2Info : public OperatorInfo {
public:


Loading…
Cancel
Save