|
|
|
@@ -33,12 +33,27 @@ namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
Status ReLUV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } |
|
|
|
|
|
|
|
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } |
|
|
|
Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
Strategys stra = strategy->GetInputDim(); |
|
|
|
Dimensions input_strategy = stra.at(0); |
|
|
|
if (input_strategy[1] != 1) { |
|
|
|
MS_LOG(ERROR) << name_ << "The second dimension is not splitable."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status ReLUV2Info::GetAttrs() { return SUCCESS; } |
|
|
|
|
|
|
|
Status ReLUV2Info::GenerateStrategies(int32_t stage_id) { |
|
|
|
Shape input0_split(inputs_shape_[0].size(), 1); |
|
|
|
// the second dimension is not splitable |
|
|
|
input0_split[1] = 0; |
|
|
|
Shapes splittable_inputs = {input0_split}; |
|
|
|
|
|
|
|
std::vector<StrategyPtr> sp_vector; |
|
|
|
@@ -97,6 +112,7 @@ Status ReLUV2Info::InferForwardCommunication() { |
|
|
|
|
|
|
|
Status ReLUV2Info::InferTensorMap() { |
|
|
|
Shape tensor_map_index; |
|
|
|
Shape tensor_map_mask; |
|
|
|
size_t size = inputs_shape_.at(0).size(); |
|
|
|
// such as 4: tensor_map_index [3,2,1,0] |
|
|
|
for (size_t i = 0; i < size; ++i) { |
|
|
|
@@ -104,9 +120,12 @@ Status ReLUV2Info::InferTensorMap() { |
|
|
|
} |
|
|
|
|
|
|
|
inputs_tensor_map_.push_back(tensor_map_index); |
|
|
|
// output and mask |
|
|
|
outputs_tensor_map_.push_back(tensor_map_index); |
|
|
|
// output |
|
|
|
outputs_tensor_map_.push_back(tensor_map_index); |
|
|
|
tensor_map_mask = tensor_map_index; |
|
|
|
// mask format NC1HWC0 |
|
|
|
tensor_map_mask.push_back(MAP_NONE); |
|
|
|
outputs_tensor_map_.push_back(tensor_map_mask); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -116,7 +135,7 @@ Status ReLUV2Info::InferTensorInfo() { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
TensorLayout input_layout, output_layout; |
|
|
|
TensorLayout input_layout, output_layout, mask_layout; |
|
|
|
// infer tensor layout |
|
|
|
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed."; |
|
|
|
@@ -129,10 +148,15 @@ Status ReLUV2Info::InferTensorInfo() { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (mask_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
TensorInfo output_tensor_info(output_layout); |
|
|
|
TensorInfo mask_tensor_info(mask_layout); |
|
|
|
// output and mask |
|
|
|
outputs_tensor_info_.push_back(output_tensor_info); |
|
|
|
outputs_tensor_info_.push_back(output_tensor_info); |
|
|
|
outputs_tensor_info_.push_back(mask_tensor_info); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
|