|
|
|
@@ -96,6 +96,7 @@ class BatchNormGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
|
|
|
|
InitResource(); |
|
|
|
is_train_ = GetAttr<bool>(kernel_node, "is_training"); |
|
|
|
if (is_train_) { |
|
|
|
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; |
|
|
|
} else { |
|
|
|
@@ -133,7 +134,6 @@ class BatchNormGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
SetTensorDescriptor(format, shape); |
|
|
|
InitSizeLists(); |
|
|
|
is_train_ = GetAttr<bool>(kernel_node, "is_training"); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -229,8 +229,8 @@ class BatchNormGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
|
|
|
|
output_size_list_.push_back(output_size_); // output |
|
|
|
output_size_list_.push_back(reserve_size_); // reserve space |
|
|
|
output_size_list_.push_back(para_size_); // save scale |
|
|
|
output_size_list_.push_back(reserve_size_); // reserve space |
|
|
|
output_size_list_.push_back(para_size_); // save mean |
|
|
|
output_size_list_.push_back(para_size_); // save variance |
|
|
|
|
|
|
|
|