Browse Source

!9666 fix hccl count bug

From: @jojobugfree
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a9fcea2ca6
1 changed files with 16 additions and 3 deletions
  1. +16
    -3
      mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc

+ 16
- 3
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc View File

@@ -18,9 +18,16 @@
#include <memory>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_context.h"
#include "utils/utils.h"

namespace mindspore {
bool IsPyNativeMode() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
}

bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
@@ -129,10 +136,16 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
block_size = input_size / LongToSize(rank_size);
total_size = total_size + block_size;
} else {
if (i == size - 1) {
block_size = input_size;
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
auto cnode = anf_node->cast<CNodePtr>();
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
} else {
block_size = input_size;
}
} else {
block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
block_size =
IsPyNativeMode() ? input_size : (input_size + align_size - 1 + filled_size) / align_size * align_size;
}
total_size = total_size + block_size;
}


Loading…
Cancel
Save