Browse Source

opt multi fun

tags/v1.1.0
chujinjin 5 years ago
parent
commit
3c269a059e
1 changed files with 3 additions and 0 deletions
  1. +3
    -0
      mindspore/ops/composite/base.py

+ 3
- 0
mindspore/ops/composite/base.py View File

@@ -395,6 +395,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))

def __call__(self, *args):
if len(self.entries) == 1:
output = self.entries[0][1](*args)
return output
types = tuple(map(mstype.get_py_obj_dtype, args))
for sigs, fn in self.entries:
if len(sigs) != len(types):


Loading…
Cancel
Save