Browse Source

!3655 gpu support BroadcastTo kernel

Merge pull request !3655 from chenweifeng/broadcast_to
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
f1a39a0f72
2 changed files with 10 additions and 3 deletions
  1. +4
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h
  2. +6
    -0
      tests/st/ops/gpu/test_broadcast_to_ops.py

+ 4
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h View File

@@ -51,11 +51,12 @@ class BroadcastToGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4";
}

-  for (int i = input_shapes.size() - 1; i >= 0; i--) {
-    input_shape_[i] = input_shapes[i];
+  size_t offset = output_shapes.size() - input_shapes.size();
+  for (size_t i = 0; i < input_shapes.size(); i++) {
+    input_shape_[i + offset] = input_shapes[i];
   }
 
-  for (int j = output_shapes.size() - 1; j >= 0; j--) {
+  for (size_t j = 0; j < output_shapes.size(); j++) {
     output_shape_[j] = output_shapes[j];
   }



+ 6
- 0
tests/st/ops/gpu/test_broadcast_to_ops.py View File

@@ -38,3 +38,9 @@ def test_broadcast():
output = P.BroadcastTo(shape)(Tensor(x1_np))
expect = np.broadcast_to(x1_np, shape)
assert np.allclose(output.asnumpy(), expect)

+    x1_np = np.random.rand(4, 5).astype(np.float32)
+    shape = (2, 3, 4, 5)
+    output = P.BroadcastTo(shape)(Tensor(x1_np))
+    expect = np.broadcast_to(x1_np, shape)
+    assert np.allclose(output.asnumpy(), expect)

Loading…
Cancel
Save