Browse Source

!2392 don't change shape of Initializer when init slice of a Parameter

Merge pull request !2392 from yihuaijie/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
f3f95b255b
2 changed files with 8 additions and 5 deletions
  1. +7
    -3
      mindspore/common/initializer.py
  2. +1
    -2
      mindspore/common/parameter.py

+ 7
- 3
mindspore/common/initializer.py View File

@@ -64,7 +64,7 @@ class Initializer:
def dtype(self, dtype):
self._dtype = dtype

def to_tensor(self, slice_index=None):
def to_tensor(self, slice_index=None, shape=None):
"""
Get the tensor format data of this Initializer.

@@ -72,12 +72,16 @@ class Initializer:
slice_index (int): Slice index of a parameter's slices.
Used when initialize a slice of a parameter, it guarantee that
devices use the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
"""
arr = None
if shape is None:
shape = self.shape

try:
arr = np.ndarray(self.shape)
arr = np.ndarray(shape)
except ValueError:
msg = "Error shape={}".format(self.shape)
msg = "Error shape={}".format(shape)
logger.error(msg)
raise ValueError(msg)



+ 1
- 2
mindspore/common/parameter.py View File

@@ -249,9 +249,8 @@ class Parameter:
if len(layout) != 3:
raise ValueError("The length of layout must be 3! layout is {}."
.format(layout))
self.init_mode.shape = layout[2]
slice_index = int(_get_slice_index(layout[0], layout[1]))
self.default_input = self.init_mode.to_tensor(slice_index)
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
else:
self.default_input = self.init_mode.to_tensor()



Loading…
Cancel
Save