Browse Source

fix bug of hccl kernel info

tags/v0.6.0-beta
WilliamLian 5 years ago
parent
commit
ea9b5468bb
2 changed files with 21 additions and 3 deletions
  1. +20
    -2
      mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc
  2. +1
    -1
      tests/st/pynative/test_pynative_resnet50.py

+ 20
- 2
mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc View File

@@ -16,12 +16,30 @@

#include "kernel/hccl/hccl_kernel_metadata.h"
#include <memory>
#include <set>
#include "utils/utils.h"
#include "kernel/hccl/hcom_util.h"
#include "session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
namespace {
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
return format;
}
if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) {
return kOpFormat_DEFAULT;
}
if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) {
return kOpFormat_DEFAULT;
}
return format;
}
} // namespace
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
kNumberTypeFloat32, kNumberTypeInt16};
@@ -36,13 +54,13 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
inputs_type.push_back(type);
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
outputs_type.push_back(type);
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();


+ 1
- 1
tests/st/pynative/test_pynative_resnet50.py View File

@@ -428,5 +428,5 @@ def test_pynative_resnet50():
cost_time = end_time - start_time
print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
if step > 1:
assert cost_time < 0.5
assert cost_time < 0.3

Loading…
Cancel
Save