| @@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::Ad | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeFloat16) { | if (dtype_ == kNumberTypeFloat16) { | ||||
| LaunchKernel<float16>(inputs, outputs); | LaunchKernel<float16>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) { | |||||
| LaunchKernel<float>(inputs, outputs); | LaunchKernel<float>(inputs, outputs); | ||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64"; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsGradCPUKernel::Launch(const std::vector<kernel | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeFloat16) { | if (dtype_ == kNumberTypeFloat16) { | ||||
| LaunchKernel<float16>(inputs, outputs); | LaunchKernel<float16>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) { | |||||
| LaunchKernel<float>(inputs, outputs); | LaunchKernel<float>(inputs, outputs); | ||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64"; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -242,7 +242,8 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker | |||||
| auto tensor_address = tensor->device_address(); | auto tensor_address = tensor->device_address(); | ||||
| MS_EXCEPTION_IF_NULL(address); | MS_EXCEPTION_IF_NULL(address); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| if (tensor_address != nullptr && tensor_address != address) { | |||||
| if (tensor_address != nullptr && tensor_address != address && | |||||
| std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU) { | |||||
| tensor->data_sync(false); | tensor->data_sync(false); | ||||
| } | } | ||||
| if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { | if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { | ||||
| @@ -234,7 +234,7 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector< | |||||
| operator_info << ") "; | operator_info << ") "; | ||||
| } | } | ||||
| operator_info << "is not support."; | operator_info << "is not support."; | ||||
| MS_LOG(EXCEPTION) << operator_info.str(); | |||||
| MS_EXCEPTION(TypeError) << operator_info.str(); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, | bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, | ||||