diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index 58ab65138c..c335ec3a46 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -354,13 +354,15 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) { } void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector &output_format, - const CNodePtr &kernel_node, size_t input_index) { + const CNodePtr &kernel_node, size_t input_index, bool force_fresh = false) { + if (real_input_node->isa() || AnfAlgo::OutputAddrExist(real_input_node, 0)) { + return; + } auto builder = std::make_shared(); // we set special device info of a input tensor. - bool is_ref = false; auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node); if (op_info != nullptr) { - is_ref = op_info->is_ref(); + force_fresh = op_info->is_ref() || force_fresh; } auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node); if (IsValueNode(real_input_node) && @@ -371,7 +373,7 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vectorBuild(), real_input_node.get()); return; } - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || force_fresh) { builder->SetOutputsFormat(output_format); std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; builder->SetOutputsDeviceType(output_type); @@ -381,6 +383,9 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vectorget_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { + return false; + } if (!input_node->isa()) { return false; } @@ -397,7 +402,7 @@ bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string info_builder->SetOutputsFormat({format}); AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get()); auto cast_input_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cast_node, 0), 0); - SetWeightFormat(cast_input_node.first, {format}, cast_node, 0); + SetWeightFormat(cast_input_node.first, {format}, cast_node, 0, true); return true; } } // namespace @@ -418,9 +423,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { continue; } - if (AnfAlgo::OutputAddrExist(real_input_node, 0)) { - continue; - } auto refresh_format = selected_kernel_info->GetInputFormat(input_index); std::vector output_format = {refresh_format}; // if not find in host convert format map means the host has not registered the convert function of this format