|
|
@@ -114,7 +114,7 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { |
|
|
channel_ = input_shape[1]; |
|
|
channel_ = input_shape[1]; |
|
|
height_ = input_shape[2]; |
|
|
height_ = input_shape[2]; |
|
|
width_ = input_shape[3]; |
|
|
width_ = input_shape[3]; |
|
|
freeze_bn_ = GetValue<int32_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); |
|
|
|
|
|
|
|
|
freeze_bn_ = GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); |
|
|
|
|
|
|
|
|
InitSizeLists(); |
|
|
InitSizeLists(); |
|
|
return true; |
|
|
return true; |
|
|
|