Browse Source

!29265 fix resize_bilinear infer

Merge pull request !29265 from jiangzhenguang/resize_bilinear
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
7bb5819889
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 15 additions and 1 deletions
  1. +1
    -1
      mindspore/python/mindspore/ops/operations/_grad_ops.py
  2. +14
    -0
      tests/ut/python/parallel/test_resizebilinear.py

+ 1
- 1
mindspore/python/mindspore/ops/operations/_grad_ops.py View File

@@ -2339,7 +2339,7 @@ class ParallelResizeBilinearGrad(PrimitiveWithInfer):
x_dtype = x['dtype']
validator.check_tensor_dtype_valid("grad_dtype", grad_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
if size_val is not None:
if size_val is None:
raise ValueError("size should be const input")
output_shape = [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]



+ 14
- 0
tests/ut/python/parallel/test_resizebilinear.py View File

@@ -239,3 +239,17 @@ def test_neighbor_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
compile_net(net)


def test_bilinear_shard_n_c_w():
"""
Feature: test ResizeBilinear shard n/c/w
Description: shard n/c/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=3)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 1, 2),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)

Loading…
Cancel
Save