|
|
|
@@ -110,7 +110,13 @@ void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, |
|
|
|
|
|
|
|
template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2, |
|
|
|
const float *dy, float *dx1, float *dx2, cudaStream_t stream); |
|
|
|
template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const int *x1, const int *x2, |
|
|
|
const int *dy, int *dx1, int *dx2, cudaStream_t stream); |
|
|
|
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, |
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, |
|
|
|
enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1, |
|
|
|
float *dx2, cudaStream_t stream); |
|
|
|
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, |
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, |
|
|
|
enum BroadcastGradOpType op, const int *x1, const int *x2, const int *dy, int *dx1, |
|
|
|
int *dx2, cudaStream_t stream); |