Browse Source

!10169 fix bug in example of MakeRefKey

From: @zhangbuxue
Reviewed-by: @zh_qh,@zhunaipan
Signed-off-by: @zh_qh
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f3aee78bde
2 changed files with 7 additions and 4 deletions
  1. +1
    -1
      mindspore/ccsrc/pipeline/pynative/base.h
  2. +6
    -3
      mindspore/ops/operations/other_ops.py

+ 1
- 1
mindspore/ccsrc/pipeline/pynative/base.h View File

@@ -66,7 +66,7 @@ struct OpExecInfo {
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;

const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
const std::set<std::string> ignore_infer_prim = {"mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
const std::set<std::string> ignore_judge_dynamic_cell = {
"Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",


+ 6
- 3
mindspore/ops/operations/other_ops.py View File

@@ -361,11 +361,14 @@ class MakeRefKey(Primitive):
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> from mindspore.ops import functional as ops
>>> import numpy as np
>>> from mindspore import Parameter, Tensor
>>> from mindspore import dtype as mstype
>>> import mindspore.ops as ops
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.y = mindspore.Parameter(Tensor(np.ones([6, 8, 10]), mindspore.int32), name="y")
... self.y = Parameter(Tensor(np.ones([6, 8, 10]), mstype.int32), name="y")
... self.make_ref_key = ops.MakeRefKey("y")
...
... def construct(self, x):
@@ -373,7 +376,7 @@ class MakeRefKey(Primitive):
... ref = ops.make_ref(key, x, self.y)
... return ref * x
...
>>> x = Tensor(np.ones([3, 4, 5]), mindspore.int32)
>>> x = Tensor(np.ones([3, 4, 5]), mstype.int32)
>>> net = Net()
>>> output = net(x)
>>> print(output)


Loading…
Cancel
Save