|
|
|
@@ -64,8 +64,6 @@ void Bucket::Launch() { |
|
|
|
MS_LOG(INFO) << "Bucket launch cost:" << (GetTime() - start) * 1e6 << " us"; |
|
|
|
} |
|
|
|
|
|
|
|
// TODO(caifubi): float16 grad cast to float32 grad |
|
|
|
|
|
|
|
void Bucket::UpdateTensorAddr() { |
|
|
|
if (grad_tensor_list_.size() != bucket_size_ || new_tensor_output_addrs_.size() != bucket_size_) { |
|
|
|
MS_LOG(EXCEPTION) << "grad_tensor_list size:" << grad_tensor_list_.size() |
|
|
|
@@ -80,7 +78,6 @@ void Bucket::UpdateTensorAddr() { |
|
|
|
// release old addr and manage addr by this Bucket. |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
auto origin_dev_ptr = device_address->GetMutablePtr(); |
|
|
|
// FreeDeviceMem(origin_dev_ptr); |
|
|
|
tensor_old_addr_list_.emplace_back(origin_dev_ptr); |
|
|
|
device_address->from_mem_pool_ = false; |
|
|
|
device_address->set_ptr(new_tensor_output_addrs_[i]); |
|
|
|
|