|
|
|
@@ -133,29 +133,32 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co |
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { |
|
|
|
auto input_kernel_node = kernel_node->input(input_index + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node); |
|
|
|
if (!input_kernel_node->isa<Parameter>()) { |
|
|
|
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(input_with_index.first); |
|
|
|
auto real_input_node = input_with_index.first; |
|
|
|
if (!real_input_node->isa<Parameter>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = |
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
|
|
|
|
auto param = input_kernel_node->cast<ParameterPtr>(); |
|
|
|
auto param = real_input_node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(param); |
|
|
|
if (!AnfAlgo::IsParameterWeight(param)) { |
|
|
|
std::vector<std::string> output_format = {kOpFormat_DEFAULT}; |
|
|
|
builder->SetOutputsFormat(output_format); |
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)}; |
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; |
|
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) || |
|
|
|
if ((AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) || |
|
|
|
(AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) { |
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; |
|
|
|
builder->SetOutputsFormat(output_format); |
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; |
|
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|