|
|
|
@@ -117,6 +117,7 @@ Status MatmulDDSInfo::InferDevMatrixShape() { |
|
|
|
dev_matrix_shape_.push_back(1); |
|
|
|
dev_matrix_shape_.push_back(1); |
|
|
|
dev_matrix_shape_.push_back(1); |
|
|
|
dev_matrix_shape_origin_ = dev_matrix_shape_; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -168,22 +169,22 @@ Status MatmulDDSInfo::InferTensorMap() { |
|
|
|
} |
|
|
|
TensorMap output_tensor_map_local_prob; |
|
|
|
// output_tensor_map_local_prob [5, 6, -1, -1, -1, -1, -1] |
|
|
|
for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) { |
|
|
|
for (size_t i = 0; i < dev_matrix_shape_origin_.size(); ++i) { |
|
|
|
if (i == 0) { |
|
|
|
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2)); |
|
|
|
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 2)); |
|
|
|
} else if (i == 1) { |
|
|
|
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1)); |
|
|
|
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 1)); |
|
|
|
} else { |
|
|
|
output_tensor_map_local_prob.push_back((int64_t)(MAP_NONE)); |
|
|
|
} |
|
|
|
} |
|
|
|
TensorMap output_tensor_map_global_prob; |
|
|
|
// output_tensor_map_global_prob [5, 6, -1, -1, -1, -1, -1] |
|
|
|
for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) { |
|
|
|
for (size_t i = 0; i < dev_matrix_shape_origin_.size(); ++i) { |
|
|
|
if (i == 0) { |
|
|
|
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2)); |
|
|
|
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 2)); |
|
|
|
} else if (i == 1) { |
|
|
|
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1)); |
|
|
|
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 1)); |
|
|
|
} else { |
|
|
|
output_tensor_map_global_prob.push_back((int64_t)(MAP_NONE)); |
|
|
|
} |
|
|
|
|