Browse Source

initial commit: add nullptr exception in GetDeviceAddress

all cudnn functions now use the new GetPossiblyNullDeviceAddress

fix batchnorm

fix ci

fix nll loss

fix cast and concat

fix cast: skip kernel if null input and output

fix ci

fix concat: allow null input

fix concat: allow for null inputs
tags/v1.5.0-rc1
Peilin Wang 4 years ago
parent
commit
6a1b1495d9
21 changed files with 78 additions and 72 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
  2. +12
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.h
  3. +19
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
  4. +13
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  5. +4
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
  6. +4
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
  7. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h
  8. +1
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
  9. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
  10. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
  11. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h
  12. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h
  13. +1
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h
  14. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_transpose_gpu_kernel.h
  15. +1
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
  16. +1
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
  17. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h
  18. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h
  19. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h
  20. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h
  21. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h View File

@@ -49,7 +49,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *workspace_addr = GetDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);


+ 12
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,10 +36,18 @@ class CastGpuKernel : public GpuKernel {

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
S *input_addr = GetDeviceAddress<S>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
S *input_addr = GetPossiblyNullDeviceAddress<S>(inputs, 0);
T *output_addr = GetPossiblyNullDeviceAddress<T>(outputs, 0);

if (input_addr == nullptr && output_addr == nullptr) {
return true;
} else if (input_addr != nullptr && output_addr != nullptr) {
Cast(input_size_, input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
MS_LOG(EXCEPTION)
<< "The input and output device addresses for CastGpuKernel should be both null or both not null.";
}

Cast(input_size_, input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}



+ 19
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h View File

@@ -43,11 +43,20 @@ class ConcatV2GpuFwdKernel : public GpuKernel {

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (input_num_ == 0) {
return true;
}

T *output = GetDeviceAddress<T>(outputs, 0);
T **inputs_device = GetDeviceAddress<T *>(workspace, 0);
int *len_axis_device = GetDeviceAddress<int>(workspace, 1);
int current_dim = 0;
for (size_t i = 0; i < inputs.size(); i++) {
inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
T *input = GetPossiblyNullDeviceAddress<T>(inputs, i);
if (input != nullptr) {
inputs_host_[current_dim] = input;
current_dim++;
}
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_,
@@ -83,14 +92,21 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
inputs_host_ = std::make_unique<T *[]>(input_num_);
len_axis_ = std::make_unique<int[]>(input_num_);
int current_dim = 0;
for (int i = 0; i < input_num_; i++) {
size_t input_size = 1;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
input_size *= input_shape[j];
}
input_size_list_.push_back(input_size * sizeof(T));
len_axis_[i] = SizeToInt(input_shape[axis_]);

if (input_size == 0) {
input_num_--;
} else {
input_size_list_.push_back(input_size * sizeof(T));
len_axis_[current_dim] = SizeToInt(input_shape[axis_]);
current_dim++;
}
}
workspace_size_list_.push_back(sizeof(T *) * input_num_);
workspace_size_list_.push_back(sizeof(int) * input_num_);


+ 13
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -99,6 +99,19 @@ class GpuKernel : public KernelMod {
if (index >= addr_list.size()) {
MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
}
if ((addr_list[index] == nullptr) || (addr_list[index]->addr == nullptr) || (addr_list[index]->size == 0)) {
MS_LOG(EXCEPTION) << "The device address is empty, address index: " << index;
}
return reinterpret_cast<T *>(addr_list[index]->addr);
}
template <typename T>
inline T *GetPossiblyNullDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
if (index >= addr_list.size()) {
MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
}
// Kernels may run normally without a workspace, so addr_list[index] may be nullptr.
if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) {
return nullptr;


+ 4
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h View File

@@ -52,18 +52,16 @@ class BatchNormGpuKernel : public GpuKernel {
auto running_variance = GetDeviceAddress<float>(inputs, 4);
T *z = nullptr;
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
z = GetDeviceAddress<T>(inputs, 5);
z = GetPossiblyNullDeviceAddress<T>(inputs, 5);
}

auto y = GetDeviceAddress<T>(outputs, 0);
auto reserve_addr = GetDeviceAddress<float>(outputs, 2);
T *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha = 1;
const float beta = 0;
if (is_train_) {
auto reserve_addr = GetPossiblyNullDeviceAddress<float>(outputs, 2);
auto save_mean = GetDeviceAddress<float>(outputs, 3);
auto save_variance = GetDeviceAddress<float>(outputs, 4);
CHECK_CUDNN_RET_WITH_EXCEPT(


+ 4
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h View File

@@ -71,8 +71,6 @@ class BatchNormGradGpuKernel : public GpuKernel {
auto scale = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
auto reserve_addr = GetDeviceAddress<float>(inputs, 5);
reserve_size_ = inputs[5]->size;
void *bias = nullptr;
T *y = nullptr;
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
@@ -88,11 +86,11 @@ class BatchNormGradGpuKernel : public GpuKernel {
dz = GetDeviceAddress<T>(outputs, 3);
}

void *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
if (is_train_) {
auto reserve_addr = GetPossiblyNullDeviceAddress<float>(inputs, 5);
reserve_size_ = inputs[5]->size;
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha_data_diff = 1;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h View File

@@ -58,8 +58,8 @@ class BiasAddGradGpuKernel : public GpuKernel {
"cudaMemcpyAsync failed.");
} else {
if (use_cudnn_) { // shared memory not satisfied or num_dim > 4
T *indices_addr = GetDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
T *indices_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 1);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(


+ 1
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h View File

@@ -46,10 +46,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *filter_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha = 1;
const float beta = 0;


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h View File

@@ -71,16 +71,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
T *dy = GetDeviceAddress<T>(inputs, 0);
T *x = GetDeviceAddress<T>(inputs, 1);
T *dw = GetDeviceAddress<T>(outputs, 0);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
T *work_space = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha = 1;
const float beta = 0;

if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
T *padded = GetPossiblyNullDeviceAddress<T>(workspace, 1);
if (data_format_ == kOpFormat_NHWC) {
CalPadNHWC(padded_size_ / sizeof(T), x, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h View File

@@ -74,14 +74,11 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
T *dy = GetDeviceAddress<T>(inputs, 0);
T *w = GetDeviceAddress<T>(inputs, 1);
T *dx = GetDeviceAddress<T>(outputs, 0);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
T *work_space = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha = 1;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
T *padded = GetPossiblyNullDeviceAddress<T>(workspace, 1);

CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h View File

@@ -45,15 +45,12 @@ class Conv3dGpuKernel : public GpuKernel {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *filter_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha = 1;
const float beta = 0;
if (use_pad_) {
T *padded_addr = GetDeviceAddress<T>(workspace, 1);
T *padded_addr = GetPossiblyNullDeviceAddress<T>(workspace, 1);
CalPad3d(padded_size_ / sizeof(T), input_addr, n_, c_, old_depth_, old_height_, old_width_,
old_depth_ + pad_depth_, old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_,
pad_left_, pad_value_, padded_addr, reinterpret_cast<cudaStream_t>(stream_ptr));


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h View File

@@ -47,10 +47,7 @@ class Conv3dGradFilterGpuKernel : public GpuKernel {
T *x = GetDeviceAddress<T>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 1);

T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
T *work_space = GetPossiblyNullDeviceAddress<T>(workspace, 0);

T *dw = nullptr;
float *dw_float32 = nullptr;
@@ -64,7 +61,7 @@ class Conv3dGradFilterGpuKernel : public GpuKernel {
const float alpha = 1;
const float beta = 0;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
T *padded = GetPossiblyNullDeviceAddress<T>(workspace, 1);
CalPad3d(padded_size_ / sizeof(T), x, n_, c_, old_depth_, old_height_, old_width_, old_depth_ + pad_depth_,
old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_, pad_left_, pad_value_, padded,
reinterpret_cast<cudaStream_t>(stream_ptr));


+ 1
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h View File

@@ -46,10 +46,7 @@ class Conv3dGradInputGpuKernel : public GpuKernel {
T *w = GetDeviceAddress<T>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 1);
T *dx = GetDeviceAddress<T>(outputs, 0);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
T *work_space = GetPossiblyNullDeviceAddress<T>(workspace, 0);

const float alpha = 1;
if (use_pad_) {


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_transpose_gpu_kernel.h View File

@@ -46,14 +46,11 @@ class Conv3dTransposeGpuFwdKernel : public GpuKernel {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *filter_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}
T *work_space = GetPossiblyNullDeviceAddress<T>(workspace, 0);
const float alpha = 1;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
T *padded = GetPossiblyNullDeviceAddress<T>(workspace, 1);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, filter_desc_, filter_addr,
input_desc_, input_addr, conv_desc_, algo_, work_space,


+ 1
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h View File

@@ -75,10 +75,7 @@ class InstanceNormGpuKernel : public GpuKernel {
float *ws_beta = GetDeviceAddress<float>(workspace, 1);
float *ws_mean = GetDeviceAddress<float>(workspace, 2);
float *ws_var = GetDeviceAddress<float>(workspace, 3);
T *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 4);
}
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 4);

size_t N = input_shape_[0];
size_t C = input_shape_[1];


+ 1
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h View File

@@ -78,10 +78,7 @@ class InstanceNormGradGpuKernel : public GpuKernel {
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
float *ws_dgamma = GetDeviceAddress<float>(workspace, 1);
float *ws_dbeta = GetDeviceAddress<float>(workspace, 2);
void *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 3);
}
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 3);

size_t N = input_shape_[0];
size_t C = input_shape_[1];


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h View File

@@ -59,8 +59,8 @@ class L2NormalizeGpuKernel : public GpuKernel {
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *reduce_workspace_addr = GetDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
T *reduce_workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 1);

const float alpha = 1;
const float beta = 0;


+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h View File

@@ -62,10 +62,10 @@ class L2NormalizeGradGpuKernel : public GpuKernel {
T *y_addr = GetDeviceAddress<T>(inputs, 1);
T *dy_addr = GetDeviceAddress<T>(inputs, 2);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
T *reduce_workspace_addr = GetDeviceAddress<T>(workspace, 0);
T *reduce_y_dy_workspace_addr = GetDeviceAddress<T>(workspace, 1);
T *workspace_addr = GetDeviceAddress<T>(workspace, 2);
T *workspace_y_dy_addr = GetDeviceAddress<T>(workspace, 3);
T *reduce_workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
T *reduce_y_dy_workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 1);
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 2);
T *workspace_y_dy_addr = GetPossiblyNullDeviceAddress<T>(workspace, 3);

const float alpha = 1;
const float beta = 0;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h View File

@@ -70,7 +70,7 @@ class LstmGpuKernel : public GpuKernel {
auto cy_addr = GetDeviceAddress<T>(outputs, 2);
auto reserved_addr = GetDeviceAddress<T>(outputs, 3);
auto states_addr = GetDeviceAddress<T>(outputs, 4);
void *workspace_addr = GetDeviceAddress<T>(workspace, 0);
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

if (!states_init_) {
CHECK_CUDNN_RET_WITH_EXCEPT(


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h View File

@@ -75,7 +75,7 @@ class LstmGradDataGpuKernel : public GpuKernel {
auto dx_addr = GetDeviceAddress<T>(outputs, 0);
auto dhx_addr = GetDeviceAddress<T>(outputs, 1);
auto dcx_addr = GetDeviceAddress<T>(outputs, 2);
void *workspace_addr = GetDeviceAddress<T>(workspace, 0);
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);

if (!states_init_) {
CHECK_CUDNN_RET_WITH_EXCEPT(


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h View File

@@ -44,7 +44,7 @@ class NLLLossGpuKernel : public GpuKernel {
T *loss_device = GetDeviceAddress<T>(outputs, 0);
S *total_weight_device = GetDeviceAddress<S>(outputs, 1);

T *tmp_loss_device = GetDeviceAddress<T>(workspace, 0);
T *tmp_loss_device = GetPossiblyNullDeviceAddress<T>(workspace, 0);
S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);

NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, loss_device, total_weight_device,


Loading…
Cancel
Save