Browse Source

resolve optimizer tuple inputs issue

tags/v0.7.0-beta
kingfo 5 years ago
parent
commit
ca7f4df7b2
1 changed files with 12 additions and 7 deletions
  1. +12
    -7
      mindspore/nn/cell.py

+ 12
- 7
mindspore/nn/cell.py View File

@@ -223,6 +223,15 @@ class Cell:
else:
object.__delattr__(self, name)

def cast_inputs(self, inputs, dst_type):
res = list()
for item in inputs:
if isinstance(item, tuple):
res.append(self.cast_inputs(item, dst_type))
else:
res.append(cast(item, dst_type))
return tuple(res)

def __call__(self, *inputs, **kwargs):
if context.get_context("mode") == context.GRAPH_MODE:
if kwargs:
@@ -250,14 +259,10 @@ class Cell:
cast_inputs = list()
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float16))
cast_inputs = self.cast_inputs(inputs, mstype.float16)
if self._mindspore_flags.get('fp32'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float32))
if cast_inputs:
cast_inputs = tuple(cast_inputs)
else:
cast_inputs = self.cast_inputs(inputs, mstype.float32)
if not cast_inputs:
cast_inputs = inputs
if self.enable_hook:
output = self._hook_construct(*cast_inputs, **kwargs)


Loading…
Cancel
Save