Browse Source

fix unnecessary sync

tags/v1.1.0
baihuawei 5 years ago
parent
commit
e6fb4b9f69
4 changed files with 9 additions and 4 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_cpu_kernel.cc
  2. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_grad_cpu_kernel.cc
  3. +2
    -1
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
  4. +1
    -1
      mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_cpu_kernel.cc View File

@@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::Ad
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
LaunchKernel<float>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64";
}
return true;
}


+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_grad_cpu_kernel.cc View File

@@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsGradCPUKernel::Launch(const std::vector<kernel
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
LaunchKernel<float>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64";
}
return true;
}


+ 2
- 1
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc View File

@@ -242,7 +242,8 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address) {
if (tensor_address != nullptr && tensor_address != address &&
std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU) {
tensor->data_sync(false);
}
if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) {


+ 1
- 1
mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc View File

@@ -234,7 +234,7 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
operator_info << ") ";
}
operator_info << "is not support.";
MS_LOG(EXCEPTION) << operator_info.str();
MS_EXCEPTION(TypeError) << operator_info.str();
}
} // namespace
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,


Loading…
Cancel
Save