Browse Source

fix AbsGrad bug

tags/v0.7.0-beta
CaoJian 5 years ago
parent
commit
c7403b5aea
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu

+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu View File

@@ -158,7 +158,7 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output); output);
case BROADCAST_TYPE_ABSGRAD: case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output); output);
} }
} }
@@ -204,7 +204,7 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
case BROADCAST_TYPE_FLOORDIV: case BROADCAST_TYPE_FLOORDIV:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output); return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD: case BROADCAST_TYPE_ABSGRAD:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
} }
} }




Loading…
Cancel
Save