Browse Source

don't set parameter's format when it's has been setted before

tags/v0.5.0-beta
WilliamLian 5 years ago
parent
commit
390b4f847a
1 changed files with 11 additions and 1 deletions
  1. +11
    -1
      mindspore/ccsrc/device/ascend/kernel_select_ascend.cc

+ 11
- 1
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc View File

@@ -162,8 +162,18 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
bool is_ref = false;
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
// we set special device info of a input tensor.
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};


Loading…
Cancel
Save