Browse Source

!7238 [DynamicShape] Add dyanmic shape CI test case

Merge pull request !7238 from caifubi/dynamic_shape
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b2160eadc4
2 changed files with 51 additions and 0 deletions
  1. +7
    -0
      mindspore/ccsrc/backend/session/executor.cc
  2. +44
    -0
      tests/st/ops/ascend/dynamic_shape/test_unique.py

+ 7
- 0
mindspore/ccsrc/backend/session/executor.cc View File

@@ -39,6 +39,13 @@ void UpdateOutputTensors(const VectorRef *outputs,
auto &output_index = iter->second.second;
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address);

if (AnfAlgo::IsDynamicShape(node)) {
auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
ShapeVector int_shape;
std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
tensor->set_shape(int_shape);
}
}
if (tensor->NeedSyncDeviceToHostImmediately()) {
tensor->data_sync(false);


+ 44
- 0
tests/st/ops/ascend/dynamic_shape/test_unique.py View File

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.unique = P.Unique()

def construct(self, x):
return self.unique(x)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue():
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
unique = Net()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()

Loading…
Cancel
Save