|
|
|
@@ -47,8 +47,10 @@ class Expander: |
|
|
|
graph_scope.set_input(*self.inputs) |
|
|
|
outputs = self._expand(graph_builder) |
|
|
|
if isinstance(outputs, (list, tuple)): |
|
|
|
self._check_output_same(outputs) |
|
|
|
graph_scope.set_output(*outputs) |
|
|
|
else: |
|
|
|
self._check_output_same([outputs]) |
|
|
|
graph_scope.set_output(outputs) |
|
|
|
|
|
|
|
graph = graph_builder.get()[0] |
|
|
|
@@ -58,6 +60,18 @@ class Expander: |
|
|
|
def _check(self): |
|
|
|
"""Check inputs""" |
|
|
|
|
|
|
|
def _check_output_same(self, outputs): |
|
|
|
for index, value in enumerate(self.outputs): |
|
|
|
if list(outputs[index].shape) != list(value['shape']): |
|
|
|
raise GKException("{} 's output shape {} is wrong. Expected:{}".format( |
|
|
|
self.__class__.__name__, list(outputs[index].shape), list(value['shape']))) |
|
|
|
if outputs[index].dtype != value['data_type']: |
|
|
|
raise GKException("{} 's output data_type {} is wrong. Expected: {}".format( |
|
|
|
self.__class__.__name__, outputs[index].dtype, value['data_type'])) |
|
|
|
if outputs[index].data_format != value['format']: |
|
|
|
raise GKException("{} 's output format {} is wrong. Expected: {}".format( |
|
|
|
self.__class__.__name__, outputs[index].data_format, value['format'])) |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def _expand(self, graph_builder): |
|
|
|
"""Expand operator, this function should be overridden in subclass""" |
|
|
|
|