|
|
|
@@ -518,13 +518,30 @@ __kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __r |
|
|
|
int X = get_global_id(0); // C4 |
|
|
|
int Y = get_global_id(1); // w |
|
|
|
int Z = get_global_id(2); // H |
|
|
|
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { |
|
|
|
|
|
|
|
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(Y * a_shape.w + X, Z)); |
|
|
|
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, 0)); |
|
|
|
FLT4 result = pown((a - b), (int4)2); |
|
|
|
int H = Z % output_shape.y; |
|
|
|
int N = Z / output_shape.y; |
|
|
|
int a_c = X < a_shape.w ? X : 0; |
|
|
|
int a_w = Y < a_shape.z ? Y : 0; |
|
|
|
int a_h = H < a_shape.y ? H : 0; |
|
|
|
int a_n = N < a_shape.x ? N : 0; |
|
|
|
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h)); |
|
|
|
int b_c = X < b_shape.w ? X : 0; |
|
|
|
int b_w = Y < b_shape.z ? Y : 0; |
|
|
|
int b_h = H < b_shape.y ? H : 0; |
|
|
|
int b_n = N < b_shape.x ? N : 0; |
|
|
|
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h)); |
|
|
|
FLT4 result; |
|
|
|
if (broadcastC_flag == 0) { |
|
|
|
result = pown((a - b), (int4)2); |
|
|
|
} else if (broadcastC_flag == 1) { |
|
|
|
result = pown((a.x - b), (int4)2); |
|
|
|
} else { |
|
|
|
result = pown((a - b.x), (int4)2); |
|
|
|
} |
|
|
|
result = clamp(result, (FLT)(act_min), (FLT)(act_max)); |
|
|
|
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); |
|
|
|
} |
|
|
|
|