|
|
|
@@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::Ad |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
if (dtype_ == kNumberTypeFloat16) { |
|
|
|
LaunchKernel<float16>(inputs, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32) { |
|
|
|
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) { |
|
|
|
LaunchKernel<float>(inputs, outputs); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64"; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|