|
|
|
@@ -363,9 +363,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { |
|
|
|
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// we set special device info of a input tensor. |
|
|
|
bool is_ref = false; |
|
|
|
auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node); |
|
|
|
@@ -376,9 +373,12 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; |
|
|
|
if (IsValueNode<tensor::Tensor>(input_kernel_node) && |
|
|
|
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { |
|
|
|
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)}; |
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
output_format = {selected_kernel_info->GetInputFormat(input_index)}; |
|
|
|
} |
|
|
|
builder->SetOutputsFormat(output_format); |
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)}; |
|
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
|
@@ -386,7 +386,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { |
|
|
|
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)}; |
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
output_format = {selected_kernel_info->GetInputFormat(input_index)}; |
|
|
|
} |
|
|
|
builder->SetOutputsFormat(output_format); |
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)}; |
|
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
|
|