Browse Source

!2720 Fix Assign operator used inside a while loop

Merge pull request !2720 from xychow/fix-assign-in-while (fixes Assign inside a while loop)
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ea475637a1
5 changed files with 69 additions and 21 deletions
  1. +1
    -1
      mindspore/ccsrc/pipeline/pipeline.cc
  2. +20
    -0
      mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
  3. +1
    -0
      mindspore/ccsrc/pipeline/static_analysis/abstract_value.h
  4. +2
    -0
      mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
  5. +45
    -20
      tests/ut/python/pipeline/infer/test_net_infer.py

+ 1
- 1
mindspore/ccsrc/pipeline/pipeline.cc View File

@@ -314,7 +314,7 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
auto weight_name = weight_node->cast<ParameterPtr>()->name();
// find the fakequant from input
int count = 0;
int max_depth = 5;
const int max_depth = 5;
while (!is_quant_cnode(x)) {
if (count >= max_depth) {
break;


+ 20
- 0
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc View File

@@ -451,6 +451,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
if (other_tensor == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
if (sparse_grad() == other->sparse_grad()) {
return shared_from_base<AbstractBase>();
}
}
auto element = element_->Join(other_tensor->element_);
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractTensor>(element, shape);
@@ -830,6 +835,21 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
return false;
}

AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
return shared_from_base<AbstractBase>();
}
auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto ref = ref_->Join(other_ref->ref());
auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);

return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
}

std::string AbstractRef::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("


+ 1
- 0
mindspore/ccsrc/pipeline/static_analysis/abstract_value.h View File

@@ -578,6 +578,7 @@ class AbstractRef : public AbstractBase {
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
}
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override {
return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
}


+ 2
- 0
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc View File

@@ -166,6 +166,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list == args_spec_list)) {
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
return joined_args_spec_list;
}
@@ -179,6 +180,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
if (!(joined_args_spec_list == args_spec_list)) {
trace_.push_back(joined_args_spec_list);
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
return joined_args_spec_list;


+ 45
- 20
tests/ut/python/pipeline/infer/test_net_infer.py View File

@@ -16,29 +16,54 @@
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, context
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
import mindspore.ops.operations as op

def test_net_infer():
""" test_net_infer """
class Net(nn.Cell):
""" Net definition """

class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
self.bn = nn.BatchNorm2d(64)
self.fc = nn.Dense(64, 10)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()

def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
self.bn = nn.BatchNorm2d(64)
self.fc = nn.Dense(64, 10)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.flatten(x)
out = self.fc(x)
return out
Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
Net()

def construct(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.flatten(x)
out = self.fc(x)
return out

def test_assign_in_while():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self, input_shape):
super().__init__()
self.assign = op.Assign()
self.inputdata = Parameter(initializer(1, input_shape), name="global_step")

def test_net_infer():
""" test_net_infer """
Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
Net()
def construct(self, x, y, z):
out = z
while x < y:
inputdata = self.inputdata
x = x + 1
out = self.assign(inputdata, z)
return out

x = Tensor(np.array(1).astype(np.int32))
y = Tensor(np.array(3).astype(np.int32))
input_shape = (1024, 512)
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
net = Net(input_shape)
ret = net(x, y, z)
assert ret == z

Loading…
Cancel
Save