|
|
|
@@ -46,9 +46,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { |
|
|
|
// Accessor override: returns a const reference to the member output_size_list_ (no copy is made).
// NOTE(review): entries are presumably byte sizes of the kernel outputs -- confirm against where the list is filled.
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } |
|
|
|
// Accessor override: returns a const reference to the member workspace_size_list_ (no copy is made).
// NOTE(review): entries are presumably byte sizes of scratch buffers requested by this kernel -- confirm with the caller.
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } |
|
|
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, |
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, |
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override { |
|
|
|
(void)workspace; |
|
|
|
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' |
|
|
|
T *d_batch_mean = GetDeviceAddress<T>(inputs, 0); |
|
|
|
T *d_batch_std = GetDeviceAddress<T>(inputs, 1); |
|
|
|
@@ -139,11 +138,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { |
|
|
|
input_size_list_.push_back(channel_size_); |
|
|
|
input_size_list_.push_back(channel_size_); |
|
|
|
input_size_list_.push_back(sizeof(int)); |
|
|
|
|
|
|
|
// 'dx' |
|
|
|
output_size_list_.push_back(input_size_); |
|
|
|
|
|
|
|
workspace_size_list_.push_back(workspace_size_); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
|