Browse Source

fix bugs when shape is empty

tags/v1.1.0
baihuawei 5 years ago
parent
commit
d2e345b3d6
3 changed files with 11 additions and 5 deletions
  1. +3
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc
  2. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
  3. +4
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h

+ 3
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc View File

@@ -40,6 +40,9 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
CheckAxis(kernel_node);
if (shape_.empty()) {
shape_.push_back(1);
}
for (size_t i = 0; i < shape_.size(); ++i) {
if (shape_[i] <= 0) {
MS_LOG(EXCEPTION) << "shape value is invalid.";


+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h View File

@@ -40,8 +40,10 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
T *weight = GetDeviceAddress<T>(inputs, 2);
T *loss = GetDeviceAddress<T>(outputs, 0);
T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (input_size_ > 0) {
BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}


+ 4
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h View File

@@ -42,8 +42,10 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
T *dloss = GetDeviceAddress<T>(inputs, 2);
T *weight = GetDeviceAddress<T>(inputs, 3);
T *dx = GetDeviceAddress<T>(outputs, 0);
BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (input_size_ > 0) {
BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
@@ -52,7 +54,6 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;


Loading…
Cancel
Save