| @@ -158,7 +158,7 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const | |||||
| return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | ||||
| output); | output); | ||||
| case BROADCAST_TYPE_ABSGRAD: | case BROADCAST_TYPE_ABSGRAD: | ||||
| return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||||
| return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||||
| output); | output); | ||||
| } | } | ||||
| } | } | ||||
| @@ -204,7 +204,7 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const | |||||
| case BROADCAST_TYPE_FLOORDIV: | case BROADCAST_TYPE_FLOORDIV: | ||||
| return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output); | return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output); | ||||
| case BROADCAST_TYPE_ABSGRAD: | case BROADCAST_TYPE_ABSGRAD: | ||||
| return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output); | |||||
| return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output); | |||||
| } | } | ||||
| } | } | ||||