| @@ -354,13 +354,15 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) { | |||
| } | |||
| void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string> &output_format, | |||
| const CNodePtr &kernel_node, size_t input_index) { | |||
| const CNodePtr &kernel_node, size_t input_index, bool force_fresh = false) { | |||
| if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) { | |||
| return; | |||
| } | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| // we set special device info of a input tensor. | |||
| bool is_ref = false; | |||
| auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node); | |||
| if (op_info != nullptr) { | |||
| is_ref = op_info->is_ref(); | |||
| force_fresh = op_info->is_ref() || force_fresh; | |||
| } | |||
| auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node); | |||
| if (IsValueNode<tensor::Tensor>(real_input_node) && | |||
| @@ -371,7 +373,7 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); | |||
| return; | |||
| } | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || force_fresh) { | |||
| builder->SetOutputsFormat(output_format); | |||
| std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; | |||
| builder->SetOutputsDeviceType(output_type); | |||
| @@ -381,6 +383,9 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string | |||
| bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string &format) { | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return false; | |||
| } | |||
| if (!input_node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| @@ -397,7 +402,7 @@ bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string | |||
| info_builder->SetOutputsFormat({format}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get()); | |||
| auto cast_input_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cast_node, 0), 0); | |||
| SetWeightFormat(cast_input_node.first, {format}, cast_node, 0); | |||
| SetWeightFormat(cast_input_node.first, {format}, cast_node, 0, true); | |||
| return true; | |||
| } | |||
| } // namespace | |||
| @@ -418,9 +423,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { | |||
| if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::OutputAddrExist(real_input_node, 0)) { | |||
| continue; | |||
| } | |||
| auto refresh_format = selected_kernel_info->GetInputFormat(input_index); | |||
| std::vector<std::string> output_format = {refresh_format}; | |||
| // if not find in host convert format map means the host has not registered the convert function of this format | |||