Browse Source

fix bug in tensor store

tags/v1.1.0
Wei Luning 5 years ago
parent
commit
f403544af2
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      mindspore/nn/cell.py

+ 7
- 2
mindspore/nn/cell.py View File

@@ -228,11 +228,12 @@ class Cell(Cell_):
cells = self.__dict__['_cells']
if name in cells:
return cells[name]
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
params_list = self.__dict__['_params_list']
if '_tensor_list' in self.__dict__:
tensor_list = self.__dict__['_tensor_list']
if name in tensor_list:
return self.cast_param(tensor_list[name])
if '_params_list' in self.__dict__:
params_list = self.__dict__['_params_list']
if name in params_list:
para_list = params_list[name]
cast_list = list()
@@ -253,6 +254,10 @@ class Cell(Cell_):
elif name in self._cells:
del self._cells[name]
else:
if '_params_list' in self.__dict__ and name in self._params_list:
del self._params_list[name]
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
del self._tensor_list[name]
object.__delattr__(self, name)
self._attr_synced = False



Loading…
Cancel
Save