Browse Source

!2914 make AbstractRef can join with AbstractTensor

Merge pull request !2914 from xychow/fix-abstractref-join
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
7cdd5581f9
2 changed files with 24 additions and 2 deletions
  1. +2
    -1
      mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
  2. +22
    -1
      tests/ut/python/parameter_feature/test_var_grad.py

+ 2
- 1
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc View File

@@ -838,7 +838,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
auto new_ref = ref_->Join(other);
return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
}
if (*this == *other) {
return shared_from_base<AbstractBase>();


+ 22
- 1
tests/ut/python/parameter_feature/test_var_grad.py View File

@@ -22,7 +22,7 @@ from mindspore.common.parameter import ParameterTuple
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)


def test_net_vargs_expand():
@@ -184,6 +184,27 @@ def test_grad_var_args_with_sens():
_ = grad_net(x, y, sens)


def test_grad_with_param_sens():
""""test grad_with_sens parameter"""

class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), name='sens', requires_grad=False)
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)

def construct(self, x, y):
return self.grad(self.net, self.weights)(x, y, self.sens)

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = SecondNet()
grad_net = GradNet(net)
_ = grad_net(x, y)


def test_var_args_grad():
class VarNet(Cell):
def __init__(self, net):


Loading…
Cancel
Save