|
|
|
@@ -52,6 +52,11 @@ bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size |
|
|
|
|
|
|
|
bool CPUDeviceAddress::SyncHostToDevice(const std::vector<int> & /*shape*/, size_t size, TypeId type, |
|
|
|
const void *host_ptr) const { |
|
|
|
if (host_ptr == ptr_) { |
|
|
|
MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
if (type == kNumberTypeFloat16) { |
|
|
|
HalfToFloat(ptr_, host_ptr, size / 2); |
|
|
|
} else if (type == kNumberTypeFloat64) { |
|
|
|
|