|
|
|
@@ -223,6 +223,15 @@ class Cell: |
|
|
|
else: |
|
|
|
object.__delattr__(self, name) |
|
|
|
|
|
|
|
def cast_inputs(self, inputs, dst_type): |
|
|
|
res = list() |
|
|
|
for item in inputs: |
|
|
|
if isinstance(item, tuple): |
|
|
|
res.append(self.cast_inputs(item, dst_type)) |
|
|
|
else: |
|
|
|
res.append(cast(item, dst_type)) |
|
|
|
return tuple(res) |
|
|
|
|
|
|
|
def __call__(self, *inputs, **kwargs): |
|
|
|
if context.get_context("mode") == context.GRAPH_MODE: |
|
|
|
if kwargs: |
|
|
|
@@ -250,14 +259,10 @@ class Cell: |
|
|
|
cast_inputs = list() |
|
|
|
if hasattr(self, "_mindspore_flags"): |
|
|
|
if self._mindspore_flags.get('fp16'): |
|
|
|
for item in inputs: |
|
|
|
cast_inputs.append(cast(item, mstype.float16)) |
|
|
|
cast_inputs = self.cast_inputs(inputs, mstype.float16) |
|
|
|
if self._mindspore_flags.get('fp32'): |
|
|
|
for item in inputs: |
|
|
|
cast_inputs.append(cast(item, mstype.float32)) |
|
|
|
if cast_inputs: |
|
|
|
cast_inputs = tuple(cast_inputs) |
|
|
|
else: |
|
|
|
cast_inputs = self.cast_inputs(inputs, mstype.float32) |
|
|
|
if not cast_inputs: |
|
|
|
cast_inputs = inputs |
|
|
|
if self.enable_hook: |
|
|
|
output = self._hook_construct(*cast_inputs, **kwargs) |
|
|
|
|