Browse Source

Optimizes the GPU kernel error descriptions for FakeLearnedScaleQuantPerChannelGrad and related operators.

tags/v1.6.0
tacyi139 4 years ago
parent
commit
37e74206f9
16 changed files with 97 additions and 96 deletions
  1. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc
  2. +3
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc
  3. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc
  4. +9
    -11
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc
  5. +9
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc
  6. +8
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc
  7. +8
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc
  8. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc
  9. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc
  10. +6
    -8
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h
  11. +6
    -9
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h
  12. +8
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/randperm_gpu_kernel.h
  13. +9
    -10
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h
  14. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc
  15. +5
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/discounted_return_gpu_kernel.h
  16. +5
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc

+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc View File

@@ -40,22 +40,22 @@ const std::vector<size_t> &FakeLearnedScaleQuantPerChannelGradGpuKernel::GetWork
}
bool FakeLearnedScaleQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num
<< ", but FakeLearnedScaleQuantPerChannelGrad GpuKernel OP needs 4 input.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 4, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num
<< ", but FakeLearnedScaleQuantPerChannelGrad GpuKernel OP needs 2 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 2, but got " << output_num;
}
quant_delay_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less than 0, require larger than 0.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of quant_delay_ cannot be less than 0, but got "
<< quant_delay_;
}
neg_trunc_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("neg_trunc"));


+ 3
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc View File

@@ -37,17 +37,16 @@ const std::vector<size_t> &FakeLearnedScaleQuantPerLayerGpuKernel::GetWorkspaceS
}
bool FakeLearnedScaleQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num
<< ", but FakeLearnedScaleQuantPerLayer GpuKernel OP needs 3 Input.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 3, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num
<< ", but FakeLearnedScaleQuantPerLayer GpuKernel OP needs 1 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 1, but got " << output_num;
}
quant_delay_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")));


+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc View File

@@ -34,22 +34,22 @@ const std::vector<size_t> &FakeLearnedScaleQuantPerLayerGradGpuKernel::GetWorksp
}
bool FakeLearnedScaleQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num
<< ", but FakeLearnedScaleQuantPerLayerGrad GpuKernel OP needs 4 input.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 4, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num
<< ", but FakeLearnedScaleQuantPerLayerGrad GpuKernel OP needs 2 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 2, but got " << output_num;
}
quant_delay_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less than 0, require larger than 0.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of quant_delay_ cannot be less than 0, but got "
<< quant_delay_;
}
neg_trunc_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("neg_trunc"));


+ 9
- 11
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc View File

@@ -43,17 +43,16 @@ const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetOutputSizeList() con
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 3, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << " but FakeQuant GpuKernel OP needs 1 output.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 1, but got " << output_num;
}
// get attribute
@@ -66,13 +65,13 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
quant_delay_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("quant_delay")));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of num_bits should be in (2, 16), but got "
<< num_bits_;
}
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of quant_delay_ cannot be less than 0, but got "
<< quant_delay_;
}
// quant min and max value
@@ -84,14 +83,13 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
// shape info for gpu
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'FakeQuantPerchannelGpuKernel', input is null";
InitSizeLists();
return true;
}
if (input_shape.empty()) {
MS_LOG(EXCEPTION) << "For 'FakeQuantPerchannelGpuKernel', input_shape is empty.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', input cannot be empty, but got empty";
}
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);


+ 9
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc View File

@@ -40,27 +40,30 @@ const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeLis
}
bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 4, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 1, but got " << output_num;
}
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
num_bits_ = static_cast<unsigned int>(GetValue<int64_t>(prim->GetAttr("num_bits")));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of num_bits should be in (2, 16), but got "
<< num_bits_;
}
quant_delay_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("quant_delay")));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of quant_delay_ cannot be less than 0, but got "
<< quant_delay_;
}
symmetric_ = GetValue<bool>(prim->GetAttr("symmetric"));
@@ -74,14 +77,13 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'FakeQuantPerchannelGradGpuKernel', input is null";
InitSizeLists();
return true;
}
if (input_shape.empty()) {
MS_LOG(EXCEPTION) << "For 'FakeQuantPerchannelGradGpuKernel', input_shape is empty.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', input cannot be empty, but got empty";
}
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);


+ 8
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc View File

@@ -43,15 +43,16 @@ const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 3, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 1, but got " << output_num;
}
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
@@ -63,13 +64,14 @@ bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
narrow_range_ = GetValue<bool>(prim->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of num_bits should be in (2, 16), but got "
<< num_bits_;
}
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of quant_delay_ cannot be less than 0, but got "
<< quant_delay_;
}
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
@@ -79,9 +81,8 @@ bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'FakeQuantPerlayerGpuKernel', input is null";
InitSizeLists();
return true;
}


+ 8
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc View File

@@ -39,27 +39,30 @@ const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() c
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 4, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 1, but got " << output_num;
}
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
num_bits_ = static_cast<unsigned int>(GetValue<int64_t>(prim->GetAttr("num_bits")));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of num_bits should be in (2, 16), but got "
<< num_bits_;
}
quant_delay_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("quant_delay")));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of quant_delay_ cannot be less than 0, but got "
<< quant_delay_;
}
symmetric_ = GetValue<bool>(prim->GetAttr("symmetric"));
@@ -74,9 +77,8 @@ bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'FakeQuantPerlayerGradGpuKernel', input is null";
InitSizeLists();
return true;
}


+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc View File

@@ -35,14 +35,15 @@ const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList
}

bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 3, but got " << input_num;
}

size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 2, but got " << output_num;
}

auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
@@ -52,14 +53,13 @@ bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {

// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'MinMaxUpdatePerchannelGpuKernel', input is null";
InitSizeLists();
return true;
}
if (input_shape.empty()) {
MS_LOG(EXCEPTION) << "For 'MinMaxUpdatePerchannelGpuKernel', input_shape is empty.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', input cannot be empty, but got empty";
}
num_channels_ = SizeToInt(input_shape[0]);
for (size_t i = 0; i < input_shape.size(); ++i) {


+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc View File

@@ -33,14 +33,15 @@ const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() co
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }

bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 3, but got " << input_num;
}

size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 2, but got " << output_num;
}

auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
@@ -50,9 +51,8 @@ bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {

// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'MinMaxUpdatePerlayerGpuKernel', input is null";
InitSizeLists();
return true;
}


+ 6
- 8
mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h View File

@@ -92,28 +92,26 @@ class RandomCategoricalGpuKernel : public GpuKernel {
}

bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomCategorical needs 3 inputs.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 3, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomCategorical has 1 output.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 1, but got " << output_num;
}
auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(logits_shape);
is_null_input_ = CHECK_SHAPE_NULL(logits_shape, kernel_name, "logits");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'RandomCategoricalGpuKernel', input is null";
InitSizeLists();
return true;
}
if (logits_shape.size() != 2) {
MS_LOG(ERROR) << "logits's dims is " << logits_shape.size() << ", but it should be only 2-D.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of logits should be 2, but got "
<< logits_shape.size();
}
batch_size_ = logits_shape[0];
num_classes_ = logits_shape[1];


+ 6
- 9
mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h View File

@@ -73,30 +73,27 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
}

bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 1, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs should be 2, but got " << output_num;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'RandomChoiceWithMaskGpuKernel', input is null";
InitSizeLists();
return true;
}
input_shape_size_ = input_shape.size();
if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) {
MS_LOG(ERROR) << "Input is " << input_shape_size_
<< "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input should be in (1, 5), but got "
<< input_shape_size_;
}
// convert size_t to int
for (auto i = 0; i < input_shape_size_; i++) {


+ 8
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/random/randperm_gpu_kernel.h View File

@@ -20,6 +20,7 @@
#include <cstdint>
#include <random>
#include <vector>
#include <string>

#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@@ -52,7 +53,8 @@ class RandpermGpuKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed in RandpermGpuKernel");

if (static_cast<size_t>(n) > max_length_) {
MS_LOG(EXCEPTION) << "RandpermGpuKernel: n (" << n << ") cannot exceed max_length_ (" << max_length_ << ")";
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', n (" << n << ") cannot exceed max_length_ (" << max_length_
<< ")";
}

// might not be a significant performance gain if this kernel is executed in cuda,
@@ -71,22 +73,21 @@ class RandpermGpuKernel : public GpuKernel {
}

bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) {
MS_LOG(ERROR) << input_count << " inputs were provided, but RandpermGpuKernel expects 1.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 1, but got " << input_count;
}

size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_count != 1) {
MS_LOG(ERROR) << "Number of outputs is " << output_count << ", but should be 1 for RandpermGpuKernel.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs should be 1, but got " << output_count;
}

max_length_ = static_cast<size_t>(GetAttr<int64_t>(kernel_node, "max_length"));
if (max_length_ < 1) {
MS_LOG(ERROR) << "For 'RandpermGpuKernel', the max_length cannot be less than 1, but got " << max_length_;
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of max_length cannot be less than 1, but got "
<< max_length_;
}
pad_ = static_cast<T>(GetAttr<int64_t>(kernel_node, "pad"));



+ 9
- 10
mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h View File

@@ -20,6 +20,7 @@
#include <cmath>
#include <set>
#include <vector>
#include <string>
#include <random>
#include <limits>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
@@ -85,17 +86,16 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
}

bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformCandidateSampler needs 1 input.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 1, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 3) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformCandidateSampler has 3 outputs.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs should be 3, but got " << output_num;
}
// getting attrs
num_true_ = GetAttr<int64_t>(kernel_node, "num_true");
@@ -112,15 +112,14 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
}

auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
if (is_null_input_) {
MS_LOG(WARNING) << "For 'UniformCandidateSamplerGpuKernel', input is null";
InitSizeLists();
return true;
}
if (input_shape.size() != 2) {
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformCandidateSampler supports only 2-D inputs.";
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input should be 2, but got "
<< input_shape.size();
}
input_size_ = input_shape[0] * input_shape[1];
if (num_sampled_ + static_cast<int64_t>(input_size_) > range_max_) {
@@ -157,7 +156,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
// pick between [0, range_max_-1]
T range;
if (range_max_ > static_cast<int64_t>(std::numeric_limits<T>::max())) {
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', range_max_ failed to cast";
}
range = static_cast<T>(range_max_);
std::uniform_int_distribution<T> distribution(0, range - 1);
@@ -186,7 +185,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
S Probability() {
S range;
if (range_max_ > static_cast<int64_t>(std::numeric_limits<S>::max())) {
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', range_max_ failed to cast";
}
range = static_cast<S>(range_max_);
MS_EXCEPTION_IF_ZERO("range", range);


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc View File

@@ -47,6 +47,7 @@ const std::vector<size_t> &BufferSampleKernel::GetOutputSizeList() const { retur
const std::vector<size_t> &BufferSampleKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }

bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
@@ -66,7 +67,7 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
const size_t cap_state_size = sizeof(curandState) * indexes_size;
void *dev_state = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(cap_state_size);
if (dev_state == nullptr) {
MS_LOG(EXCEPTION) << "Failed to alloc dev_state, size is " << cap_state_size;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', failed to alloc dev_state, size is " << cap_state_size;
}
devStates_ = reinterpret_cast<curandState *>(dev_state);



+ 5
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/discounted_return_gpu_kernel.h View File

@@ -35,19 +35,20 @@ class DiscountedReturnGpuKernel : public GpuKernel {
~DiscountedReturnGpuKernel() = default;

bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
gamma_ = AnfAlgo::GetNodeAttr<float>(kernel_node, kGammaAttrName);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kInputNum) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but DiscountedReturnGpuKernel needs " << kInputNum;
return false;
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be " << kInputNum << ", but got "
<< input_num;
}

const std::vector<size_t> &reward_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
const std::vector<size_t> &done_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (reward_shape.size() == 0) {
MS_LOG(ERROR) << "Reward is " << reward_shape.size()
<< "-D, but DiscountedReturnGpuKernel supports 1-D and above.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of reward cannot be 0, but got "
<< reward_shape.size();
}

// Reshape reward to [timestep, env, else], done to [timestep, env], last_value to [env, else].


+ 5
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc View File

@@ -27,6 +27,7 @@ const std::vector<size_t> &TrtKernel::GetOutputSizeList() const { return output_
const std::vector<size_t> &TrtKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }

bool TrtKernel::Init(const CNodePtr &kernel_node) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; i++) {
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
@@ -47,7 +48,8 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {

auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
if (!trt_loader.nvinfer_loaded()) {
MS_LOG(EXCEPTION) << "Install Tensor-RT and export LD_LIBRARY_PATH=${TENSORRT_HOME}/lib:$LD_LIBRARY_PATH.";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', install Tensor-RT and export LD_LIBRARY_PATH=${TENSORRT_HOME}"
<< "/lib:$LD_LIBRARY_PATH.";
}
runtime_ = trt_loader.CreateInferRuntime(&Singleton<TrtLogger>::Instance());
MS_EXCEPTION_IF_NULL(runtime_);
@@ -55,8 +57,8 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {
engine_ = TrtPtr(runtime_->deserializeCudaEngine(serialize_.c_str(), serialize_.size(), nullptr));
MS_EXCEPTION_IF_NULL(engine_);
if (SizeToInt(input_num + output_num) != engine_->getNbBindings()) {
MS_LOG(EXCEPTION) << "Inputs and outputs num not match. Got: " << input_num + output_num
<< ", expect: " << engine_->getNbBindings();
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs add the number of outputs should be "
<< engine_->getNbBindings() << ", but got " << (input_num + output_num);
}

context_ = TrtPtr(engine_->createExecutionContext());


Loading…
Cancel
Save