|
|
|
@@ -158,7 +158,7 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const |
|
|
|
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, |
|
|
|
output); |
|
|
|
case BROADCAST_TYPE_ABSGRAD: |
|
|
|
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, |
|
|
|
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, |
|
|
|
output); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -204,7 +204,7 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const |
|
|
|
case BROADCAST_TYPE_FLOORDIV: |
|
|
|
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output); |
|
|
|
case BROADCAST_TYPE_ABSGRAD: |
|
|
|
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output); |
|
|
|
return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|