|
|
|
@@ -92,6 +92,7 @@ void Bucket::CalculateMean() { |
|
|
|
MS_EXCEPTION_IF_NULL(parallel_context); |
|
|
|
auto grad_mean = parallel_context->gradients_mean(); |
|
|
|
if (!grad_mean) { |
|
|
|
UpdateTensorOutputAddr(ar_output_addr_); |
|
|
|
return; |
|
|
|
} |
|
|
|
if (launch_mul_ == nullptr) { |
|
|
|
@@ -102,12 +103,16 @@ void Bucket::CalculateMean() { |
|
|
|
launch_mul_->SetInputAddr(ar_output_addr_); |
|
|
|
// launch mean |
|
|
|
launch_mul_->LaunchOpKernel(); |
|
|
|
// store output tensor addr |
|
|
|
// store tensor output addr |
|
|
|
auto launch_output = launch_mul_->GetKernelOutputAddr(); |
|
|
|
if (launch_output.size() != 1) { |
|
|
|
MS_LOG(ERROR) << "launch mul outputs should have one output"; |
|
|
|
MS_LOG(EXCEPTION) << "launch mul outputs should have one output"; |
|
|
|
} |
|
|
|
uint8_t *tensor_output = launch_output[0]; |
|
|
|
UpdateTensorOutputAddr(launch_output[0]); |
|
|
|
} |
|
|
|
|
|
|
|
void Bucket::UpdateTensorOutputAddr(uint8_t *addr) { |
|
|
|
uint8_t *tensor_output = addr; |
|
|
|
for (size_t i = 0; i < bucket_size_; ++i) { |
|
|
|
new_tensor_output_addrs_.emplace_back(tensor_output); |
|
|
|
tensor_output += align_size_list_[i]; |
|
|
|
|