Browse Source

!2040 fix paramter is metatensor bug in pynative mode

Merge pull request !2040 from flywind/fix_pynative_bug
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
89fce0e41f
2 changed files with 5 additions and 1 deletions
  1. +5
    -0
      mindspore/common/parameter.py
  2. +0
    -1
      mindspore/nn/cell.py

+ 5
- 0
mindspore/common/parameter.py View File

@@ -16,6 +16,7 @@
"""Parameter for cell."""
import numbers
from copy import copy, deepcopy
from mindspore import context
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
@@ -61,6 +62,8 @@ class Parameter:
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()

def __repr__(self):
format_str = 'Parameter (name={name})'
@@ -142,6 +145,8 @@ class Parameter:
if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape)
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)



+ 0
- 1
mindspore/nn/cell.py View File

@@ -202,7 +202,6 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
self.init_parameters_data()
orign_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)


Loading…
Cancel
Save