|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -113,6 +113,9 @@ void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, |
|
|
|
NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, grad_x1, grad_x2, op, x1, x2, dy, dx1, dx2); |
|
|
|
} |
|
|
|
|
|
|
|
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, |
|
|
|
const double *x1, const double *x2, const double *dy, double *dx1, double *dx2, |
|
|
|
cudaStream_t stream); |
|
|
|
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, |
|
|
|
const float *x1, const float *x2, const float *dy, float *dx1, float *dx2, |
|
|
|
cudaStream_t stream); |
|
|
|
@@ -124,6 +127,10 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool & |
|
|
|
template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, |
|
|
|
const int64_t *x1, const int64_t *x2, const int64_t *dy, int64_t *dx1, int64_t *dx2, |
|
|
|
cudaStream_t stream); |
|
|
|
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, |
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, |
|
|
|
const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const double *x1, |
|
|
|
const double *x2, const double *dy, double *dx1, double *dx2, cudaStream_t stream); |
|
|
|
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, |
|
|
|
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, |
|
|
|
const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1, |
|
|
|
|