|
|
|
@@ -134,21 +134,23 @@ class CompositeGraph: |
|
|
|
return red_axis |
|
|
|
|
|
|
|
attr = {} |
|
|
|
if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'): |
|
|
|
return attr |
|
|
|
for a in op['attr']: |
|
|
|
if a['name'] == 'axis': |
|
|
|
red_axis, dim_size = [], len(inputs[0].shape) |
|
|
|
if not a['value']: |
|
|
|
red_axis = _get_axis_while_none(inputs[0].shape, output.shape) |
|
|
|
else: |
|
|
|
if isinstance(a['value'], int): |
|
|
|
a['value'] = [a['value']] |
|
|
|
for i in a['value']: |
|
|
|
red_axis.append(i if i >= 0 else dim_size + i) |
|
|
|
attr['reduce_axis'] = red_axis |
|
|
|
if a['name'] == "reduce_output_fuse": |
|
|
|
attr['reduce_output_fuse'] = a['value'] |
|
|
|
if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'): |
|
|
|
for a in op['attr']: |
|
|
|
if a['name'] == 'axis': |
|
|
|
red_axis, dim_size = [], len(inputs[0].shape) |
|
|
|
if not a['value']: |
|
|
|
red_axis = _get_axis_while_none(inputs[0].shape, output.shape) |
|
|
|
else: |
|
|
|
if isinstance(a['value'], int): |
|
|
|
a['value'] = [a['value']] |
|
|
|
for i in a['value']: |
|
|
|
red_axis.append(i if i >= 0 else dim_size + i) |
|
|
|
attr['reduce_axis'] = red_axis |
|
|
|
if a['name'] == "reduce_output_fuse": |
|
|
|
attr['reduce_output_fuse'] = a['value'] |
|
|
|
elif op['attr']: |
|
|
|
for a in op['attr']: |
|
|
|
attr[a['name']] = a['value'] |
|
|
|
return attr |
|
|
|
|
|
|
|
builder = GraphBuilder() |
|
|
|
|