
Clean code

r1.7
tanghuikang · 4 years ago
commit dbcdcc2daa
23 changed files with 56 additions and 61 deletions
1. +5 -2  mindspore/ccsrc/backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.cc
2. +1 -1  mindspore/ccsrc/backend/common/somas/somas_solver_pre.cc
3. +4 -3  mindspore/ccsrc/kernel/kash/kernel_pack.cc
4. +1 -1  mindspore/ccsrc/kernel/kernel.h
5. +1 -1  mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc
6. +1 -1  mindspore/ccsrc/plugin/device/ascend/hal/device/executor/hccl_dynamic_kernel.cc
7. +1 -1  mindspore/ccsrc/plugin/device/ascend/hal/device/ge_runtime/task/profiler_task.cc
8. +0 -2  mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc
9. +23 -24  mindspore/ccsrc/plugin/device/ascend/hal/hccl_adapter/hccl_adapter.cc
10. +1 -0  mindspore/ccsrc/plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h
11. +2 -2  mindspore/ccsrc/plugin/device/ascend/kernel/rts/assign.cc
12. +2 -2  mindspore/ccsrc/plugin/device/ascend/kernel/rts/label_goto.cc
13. +2 -3  mindspore/ccsrc/plugin/device/ascend/kernel/rts/profiling_kernel_mod.cc
14. +0 -1  mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/fusion_tbe_json_creator.cc
15. +1 -1  mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
16. +3 -3  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.cc
17. +1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h
18. +1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.cc
19. +1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h
20. +1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.cc
21. +1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.h
22. +1 -1  mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc
23. +2 -2  mindspore/ccsrc/runtime/device/kernel_runtime.cc

+5 -2  mindspore/ccsrc/backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.cc

@@ -20,6 +20,7 @@

namespace mindspore {
namespace opt {
+ constexpr const int64_t kFusionGap = 2;
bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
mindspore::HashMap<int64_t, bool> forward_allgather_recompute_value_in_fusion_group;
@@ -81,12 +82,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
- int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
+ int64_t destination_fusion_id =
+   current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + kFusionGap;
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
}
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
- int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
+ int64_t destination_fusion_id =
+   current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + kFusionGap;
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
}
}


+1 -1  mindspore/ccsrc/backend/common/somas/somas_solver_pre.cc

@@ -71,7 +71,7 @@ Status SomasSolverPre::AddContiguousInfoInMultiMaps(const vector<vector<size_t>>
for (size_t i = 0; i < aux.size() - 1; i++) {
auto index1 = aux[i];
auto index2 = aux[i + 1];
- if (CheckTensors(pTensors, index1, index2) == FAILED) {
+ if (CheckTensors(pTensors, SizeToUint(index1), SizeToUint(index2)) == FAILED) {
return FAILED;
}
for (size_t sol = 0; sol < vecTensorsMap->size(); sol++) {
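
Why SizeToUint here: CheckTensors evidently expects 32-bit indices, and wrapping the size_t values makes the narrowing explicit at the call site instead of implicit. As a sketch of what such a checked helper can look like (illustrative only, not MindSpore's actual SizeToUint):

#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Named, checked size_t -> uint32_t narrowing: the conversion is visible at
// the call site and rejects values that do not fit in 32 bits.
inline std::uint32_t SizeToUintChecked(std::size_t v) {
  if (v > std::numeric_limits<std::uint32_t>::max()) {
    throw std::out_of_range("size_t value does not fit in uint32_t");
  }
  return static_cast<std::uint32_t>(v);
}

int main() {
  std::size_t index1 = 42;
  std::uint32_t narrow = SizeToUintChecked(index1);  // explicit, checked narrowing
  return narrow == 42 ? 0 : 1;
}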


+4 -3  mindspore/ccsrc/kernel/kash/kernel_pack.cc

@@ -25,6 +25,7 @@
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
+ constexpr size_t kJsonSuffixLength = 5;
namespace {
bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) {
if (js.find("sha256") == js.end()) {
@@ -108,7 +109,7 @@ bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &
}

if (processor == kProcessorCuda) {
- std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx";
+ std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + ".ptx";
std::ifstream kernelbin(bin_f);
if (!kernelbin.is_open()) {
MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta.";
@@ -140,7 +141,7 @@ bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &
}

std::string binfile_suffix = js["binFileSuffix"];
- std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfile_suffix;
+ std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + binfile_suffix;
if (binfile_suffix == ".so") {
// change "xx/xx.so" -> "xx/libxx.so"
auto sp = bin_f.rfind('/');
@@ -234,7 +235,7 @@ bool KernelPack::LoadKernelMeta(const std::string &json_f) {
}
ParseKernelJson(js);

- std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix;
+ std::string bin_f = json_f.substr(0, json_f.length() - kJsonSuffixLength) + kernel_json_info_.bin_file_suffix;
if (kernel_json_info_.bin_file_suffix == ".so") {
// change "xx/xx.so" -> "xx/libxx.so"
auto sp = bin_f.rfind('/');
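
The magic number 5 that the three substr calls strip is the length of the ".json" suffix; naming it documents that intent in one place. A minimal sketch of the idiom (SwapJsonSuffix is a hypothetical helper, not part of kernel_pack.cc; a defensive version would first verify the path really ends in ".json"):

#include <cstddef>
#include <iostream>
#include <string>

constexpr std::size_t kJsonSuffixLength = 5;  // strlen(".json")

// Swap the ".json" suffix for another one, e.g. ".ptx" or ".so".
std::string SwapJsonSuffix(const std::string &json_path, const std::string &new_suffix) {
  return json_path.substr(0, json_path.length() - kJsonSuffixLength) + new_suffix;
}

int main() {
  std::cout << SwapJsonSuffix("kernelmeta/op.json", ".ptx") << "\n";  // kernelmeta/op.ptx
  return 0;
}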


+1 -1  mindspore/ccsrc/kernel/kernel.h

@@ -187,7 +187,7 @@ class KernelMod {
explicit KernelMod(const AnfNodePtr &anf_node_ptr) : anf_node_(anf_node_ptr) {}
virtual ~KernelMod() = default;

- bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
+ bool LaunchKernel(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
stream_ptr);
}
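
One plausible motivation for this rename: KernelMod also declares a virtual Launch(inputs, workspaces, outputs, stream), and in C++ a derived class that overrides one Launch overload hides every other base-class Launch from unqualified lookup. Giving the non-virtual convenience wrapper its own name sidesteps that. A self-contained sketch of the pitfall (Base and Derived are illustrative, not MindSpore types):

#include <vector>

struct Base {
  virtual ~Base() = default;
  virtual bool Launch(const std::vector<int> &inputs) { return !inputs.empty(); }
  // Non-virtual convenience overload; sharing the name Launch means it gets
  // hidden in any subclass that overrides the virtual one.
  bool Launch(int packed) { return Launch(std::vector<int>{packed}); }
};

struct Derived : Base {
  bool Launch(const std::vector<int> &inputs) override { return !inputs.empty(); }
};

int main() {
  Derived d;
  // d.Launch(42);                    // does not compile: the int overload is hidden
  return d.Base::Launch(42) ? 0 : 1;  // reachable only with explicit qualification
}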


+1 -1  mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc

@@ -365,7 +365,7 @@ bool AscendKernelRuntime::Init() {
return true;
}

- bool AscendKernelRuntime::LoadData(const session::KernelGraph & /* graph */) {
+ bool AscendKernelRuntime::LoadData(const session::KernelGraph &) {
#ifdef ENABLE_DEBUGGER
MS_LOG(INFO) << "Start load step";
MS_EXCEPTION_IF_NULL(debugger_);


+1 -1  mindspore/ccsrc/plugin/device/ascend/hal/device/executor/hccl_dynamic_kernel.cc

@@ -81,7 +81,7 @@ void HcclDynamicKernel::StaticShapeExecute() {
MS_EXCEPTION_IF_NULL(kernel_mod);
KernelLaunchInfo kernel_launch_info;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
- kernel_mod->Launch(kernel_launch_info, stream_);
+ kernel_mod->LaunchKernel(kernel_launch_info, stream_);
}

void HcclDynamicKernel::Execute() {


+1 -1  mindspore/ccsrc/plugin/device/ascend/hal/device/ge_runtime/task/profiler_task.cc

@@ -31,7 +31,7 @@ ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_
stream_ = stream_list[stream_id];
}

- ProfilerTask::~ProfilerTask() {}
+ ProfilerTask::~ProfilerTask() { stream_ = nullptr; }

void ProfilerTask::Distribute() {
MS_LOG(INFO) << "ProfilerTask Distribute start.";


+0 -2  mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc

@@ -298,8 +298,6 @@ void AscendDeviceContext::Destroy() {
graph_event_.clear();
rank_id_ = 0;
if (runtime_instance_) {
- // TODO(lzlang): Destroy runtime instance after fully support MindRT, otherwise runtime will be destructed
- // repeatedly. runtime_instance_->ReleaseDeviceRes();
runtime_instance_ = nullptr;
}
AscendGraphOptimization::GetInstance().Reset();


+23 -24  mindspore/ccsrc/plugin/device/ascend/hal/hccl_adapter/hccl_adapter.cc

@@ -40,17 +40,6 @@ static constexpr const char *kHcclAlgoOption = "HCCL_algorithm";
return HcclResult::HCCL_E_RESERVED; \
}

- #define CHECK_EXCUTION_MODE() \
-   do { \
-     auto hccl_mode = GetCurrentHcclMode(); \
-     if (hccl_mode != hccl_mode_) { \
-       MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) \
-                         << " but current execution mode is " << GetHcclModeString(hccl_mode) \
-                         << ". Please set the execution mode before HCCL init(), and then do not " \
-                            "change it in the subsequent script"; \
-     } \
-   } while (0)

static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
std::string_view rank_file) {
auto env_deploy_mode = mindspore::common::GetEnv(kHcclDeployModeEnv);
@@ -159,6 +148,16 @@ HcclMode HcclAdapter::GetCurrentHcclMode() const {
}
}

+ void HcclAdapter::CheckExcutionMode() const {
+   auto hccl_mode = GetCurrentHcclMode();
+   if (hccl_mode != hccl_mode_) {
+     MS_LOG(EXCEPTION) << "HCCL is initialized in " << GetHcclModeString(hccl_mode_) << " but current execution mode is "
+                       << GetHcclModeString(hccl_mode)
+                       << ". Please set the execution mode before HCCL init(), and then do not change it in the "
+                          "subsequent script";
+   }
+ }

std::string HcclAdapter::GetHcclModeString(HcclMode hccl_mode) {
static std::map<HcclMode, std::string> kHcclModeString = {
{HcclMode::kGraph, "GRAPH_MODE"},
@@ -307,14 +306,14 @@ std::string HcclAdapter::GetHcclType(const AnfNodePtr &node) {

HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root,
aclrtStream stream) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_broadcast_);
return launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm_, stream);
}

HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_all_reduce_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -323,7 +322,7 @@ HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t c

HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_reduce_scatter_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -332,7 +331,7 @@ HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64

HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
aclrtStream stream, const std::string &group) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_all_gather_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -341,7 +340,7 @@ HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t c

HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank,
aclrtStream stream, const std::string &group) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_send_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -350,7 +349,7 @@ HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType da

HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
aclrtStream stream, const std::string &group) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(launch_hccl_recv_);
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
@@ -474,7 +473,7 @@ bool HcclAdapter::FinalizeHcclComm() {
}

HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_create_group_);
return hccl_create_group_(group.c_str(), rank_num, rank_ids);
}
@@ -485,25 +484,25 @@ HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const {
}

HcclResult HcclAdapter::HcclGetRankId(uint32_t *rank_id) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_);
return single_op_hccl_get_rank_id_(hccl_comm_, rank_id);
}

HcclResult HcclAdapter::HcclGetRankSize(uint32_t *rank_size) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_);
return single_op_hccl_get_rank_size_(hccl_comm_, rank_size);
}

HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_get_rank_id_);
return hccl_get_rank_id_(group.c_str(), rank_id);
}

HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_get_rank_size_);
return hccl_get_rank_size_(group.c_str(), rank_size);
}
@@ -537,13 +536,13 @@ bool HcclAdapter::FinalizeHcclExec() {
}

HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_);
return hccl_exec_enqueue_op_(op_info, callback);
}

HcclResult HcclAdapter::HcclExecAllToAllv(const ::HcomAllToAllVParams &params, const HExecCallBack &callback) const {
- CHECK_EXCUTION_MODE();
+ CheckExcutionMode();
CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_);
return hccl_exec_enqueue_all_to_all_v_(params, callback);
}
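
Converting the CHECK_EXCUTION_MODE() macro into the CheckExcutionMode() member function replaces textual expansion at thirteen call sites with a single type-checked definition that appears in backtraces and can be stepped through in a debugger. A generic sketch of the macro-to-method pattern (Adapter and its fields are hypothetical, not the HcclAdapter internals):

#include <stdexcept>
#include <string>

class Adapter {
 public:
  void DoCollectiveOp() const {
    CheckMode();  // one function call where a macro used to expand
  }

 private:
  // The check now lives in exactly one place and obeys the type system.
  void CheckMode() const {
    if (current_mode_ != init_mode_) {
      throw std::runtime_error("initialized in " + init_mode_ + " but running in " + current_mode_);
    }
  }
  std::string init_mode_ = "GRAPH_MODE";
  std::string current_mode_ = "GRAPH_MODE";
};

int main() {
  Adapter adapter;
  adapter.DoCollectiveOp();  // modes match, so no exception is thrown
  return 0;
}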


+1 -0  mindspore/ccsrc/plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h

@@ -106,6 +106,7 @@ class HcclAdapter {
bool FinalizeHcclExec();

HcclMode GetCurrentHcclMode() const;
+ void CheckExcutionMode() const;
static std::string GetHcclModeString(HcclMode hccl_mode);

void *plugin_handle_ = nullptr;


+2 -2  mindspore/ccsrc/plugin/device/ascend/kernel/rts/assign.cc

@@ -28,8 +28,8 @@ AssignKernel::AssignKernel() {}

AssignKernel::~AssignKernel() {}

- bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /* workspace */,
-                           const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
+ bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                           const std::vector<AddressPtr> &, void *stream_ptr) {
if (inputs.size() != 2) {
MS_LOG(ERROR) << "inputs size is not two";
return false;


+2 -2  mindspore/ccsrc/plugin/device/ascend/kernel/rts/label_goto.cc

@@ -43,8 +43,8 @@ bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) {
return true;
}

- bool LabelGotoKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, const std::vector<AddressPtr> & /*workspace*/,
-                              const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
+ bool LabelGotoKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                              const std::vector<AddressPtr> &, void *) {
MS_LOG(INFO) << "LabelGotoKernel launch";
return true;
}


+2 -3  mindspore/ccsrc/plugin/device/ascend/kernel/rts/profiling_kernel_mod.cc

@@ -50,9 +50,8 @@ bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) {
return true;
}

- bool ProfilingKernelMod::Launch(const std::vector<AddressPtr> & /*inputs*/,
-                                 const std::vector<AddressPtr> & /*workspace*/,
-                                 const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) {
+ bool ProfilingKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                 const std::vector<AddressPtr> &, void *) {
return true;
}
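
assign.cc, label_goto.cc, and profiling_kernel_mod.cc all make the same change: commented-out parameter names like /*inputs*/ become unnamed parameters. Both spellings suppress unused-parameter warnings while preserving the override signature; the unnamed form is simply the terser convention. A sketch with illustrative types:

#include <vector>

struct Kernel {
  virtual ~Kernel() = default;
  virtual bool Launch(const std::vector<int> &inputs, const std::vector<int> &workspace, void *stream) = 0;
};

struct NoopKernel : Kernel {
  // workspace and stream are intentionally ignored, so they stay unnamed;
  // the signature still matches the base class exactly.
  bool Launch(const std::vector<int> &inputs, const std::vector<int> &, void *) override {
    return !inputs.empty();
  }
};

int main() {
  NoopKernel k;
  return k.Launch({1, 2, 3}, {}, nullptr) ? 0 : 1;
}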



+0 -1  mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/fusion_tbe_json_creator.cc

@@ -213,7 +213,6 @@ bool FusionBuildTbeJsonCreator::GenInputsJson(const AnfNodePtr &anf_node, nlohma
input_desc_list_tmp.emplace_back(optional_input_desc);
}
std::vector<nlohmann::json> input_desc_list;
- // TODO(jjf): error when reordered op have input not in input_nodes.
TbeAdapter::InputOrderPass<nlohmann::json>(cnode, input_desc_list_tmp, &input_desc_list);
(*compute_json)[kJInputDesc] = input_desc_list;
return true;


+1 -1  mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc

@@ -180,7 +180,7 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
SupportFormat support_format;
- reduce_selecter.GetShapeInfo(&support_format);
+ (void)reduce_selecter.GetShapeInfo(&support_format);
(void)reduce_selecter.IsReduceSupport5HD(&support_format);
(void)reduce_selecter.IsReduceSupportFracZ(&support_format);
(void)reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format);
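
The (void) cast marks the discarded return value of GetShapeInfo as deliberate, matching the neighboring IsReduceSupport* calls and satisfying lint rules or [[nodiscard]] attributes. A minimal sketch (TryRefresh is hypothetical):

#include <cstdio>

[[nodiscard]] bool TryRefresh() {
  std::puts("refreshing");
  return true;
}

int main() {
  (void)TryRefresh();  // explicitly discarded: no [[nodiscard]] warning
  return 0;
}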


+3 -3  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.cc

@@ -24,8 +24,8 @@

namespace mindspore {
namespace opt {
- void BatchMatmulDropoutDoMaskV3FusionPass::MatchBatchMatmulDropoutDoMaskV3(
-     const CNodePtr &cnode, const session::KernelGraph & /* kernel_graph */, FusedNodeRecord *candidate_fusion) {
+ void BatchMatmulDropoutDoMaskV3FusionPass::MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode,
+                                                                            FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto batch_matmul = cnode->input(1);
@@ -50,7 +50,7 @@ void BatchMatmulDropoutDoMaskV3FusionPass::MatchSingleFusionPattern(const sessio
MS_EXCEPTION_IF_NULL(cnode);

if (common::AnfAlgo::GetCNodeName(cnode) == kDropoutDoMaskV3OpName) {
- MatchBatchMatmulDropoutDoMaskV3(cnode, kernel_graph, candidate_fusion);
+ MatchBatchMatmulDropoutDoMaskV3(cnode, candidate_fusion);
}
}
}


+1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h

@@ -35,8 +35,7 @@ class BatchMatmulDropoutDoMaskV3FusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

private:
- void MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                      FusedNodeRecord *candidate_fusion);
+ void MatchBatchMatmulDropoutDoMaskV3(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore


+1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.cc

@@ -25,7 +25,6 @@
namespace mindspore {
namespace opt {
void MatmulDropoutDoMaskV3AddFusionPass::MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode,
- const session::KernelGraph & /* kernel_graph */,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
@@ -59,7 +58,7 @@ void MatmulDropoutDoMaskV3AddFusionPass::MatchSingleFusionPattern(const session:
MS_EXCEPTION_IF_NULL(cnode);

if (common::AnfAlgo::GetCNodeName(cnode) == kAddOpName) {
- MatchMatmulDropoutDoMaskV3Add(cnode, kernel_graph, candidate_fusion);
+ MatchMatmulDropoutDoMaskV3Add(cnode, candidate_fusion);
}
}
}


+1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h

@@ -35,8 +35,7 @@ class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

private:
- void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                    FusedNodeRecord *candidate_fusion);
+ void MatchMatmulDropoutDoMaskV3Add(const CNodePtr &cnode, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore


+1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.cc

@@ -24,7 +24,6 @@
namespace mindspore {
namespace opt {
void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
- const session::KernelGraph & /* kernel_graph */,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
@@ -62,7 +61,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
- MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
+ MatchMatmulEltwise(cnode, eltwise_input, candidate_fusion);
}
}
}


+1 -2  mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.h

@@ -37,8 +37,7 @@ class MatmulEltwiseFusionPass : public FusionBasePass {
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

private:
- void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
-                         FusedNodeRecord *candidate_fusion);
+ void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore


+1 -1  mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc

@@ -903,7 +903,7 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap
KernelLaunchInfo kernel_launch_info;
GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
MS_EXCEPTION_IF_NULL(stream_);
- auto ret = kernel_mod->Launch(kernel_launch_info, stream_);
+ auto ret = kernel_mod->LaunchKernel(kernel_launch_info, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;


+2 -2  mindspore/ccsrc/runtime/device/kernel_runtime.cc

@@ -1294,7 +1294,7 @@ bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_
start->set_record_stream(stream);
end->set_record_stream(stream);
start->RecordEvent();
- bool ret = kernel_mod->Launch(kernel_launch_info, stream);
+ bool ret = kernel_mod->LaunchKernel(kernel_launch_info, stream);
if (!ret) {
MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name is : " << op_name;
}
@@ -1523,7 +1523,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
} else {
- ret = kernel_mod->Launch(kernel_launch_info, stream);
+ ret = kernel_mod->LaunchKernel(kernel_launch_info, stream);
}
if (!ret) {
return ret;

