|
|
@@ -32,7 +32,7 @@ namespace parallel { |
|
|
* prelu has 2 input |
|
|
* prelu has 2 input |
|
|
* A: A float tensor of shape [NCHW] representing the output of the preview layer. |
|
|
* A: A float tensor of shape [NCHW] representing the output of the preview layer. |
|
|
* w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. |
|
|
* w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. |
|
|
* the strategy of w should equal to the channel dimension of strategy of A |
|
|
|
|
|
|
|
|
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 |
|
|
*/ |
|
|
*/ |
|
|
Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { |
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { |
|
|
@@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
} |
|
|
} |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) { |
|
|
|
|
|
|
|
|
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) { |
|
|
if (is_auto_parallel_) { |
|
|
if (is_auto_parallel_) { |
|
|
MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; |
|
|
MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; |
|
|
} else { |
|
|
} else { |
|
|
@@ -107,7 +107,11 @@ Status PReLUInfo::InferTensorMap() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
TensorMap param_tensor_map; |
|
|
TensorMap param_tensor_map; |
|
|
param_tensor_map.push_back(input_tensor_map.at(1)); |
|
|
|
|
|
|
|
|
if (inputs_shape_[1][0] == 1) { |
|
|
|
|
|
param_tensor_map.push_back(-1); |
|
|
|
|
|
} else { |
|
|
|
|
|
param_tensor_map.push_back(input_tensor_map.at(1)); |
|
|
|
|
|
} |
|
|
inputs_tensor_map_.push_back(input_tensor_map); |
|
|
inputs_tensor_map_.push_back(input_tensor_map); |
|
|
inputs_tensor_map_.push_back(param_tensor_map); |
|
|
inputs_tensor_map_.push_back(param_tensor_map); |
|
|
outputs_tensor_map_.push_back(input_tensor_map); |
|
|
outputs_tensor_map_.push_back(input_tensor_map); |
|
|
|