|
|
|
@@ -154,14 +154,6 @@ class Merge(PrimitiveWithInfer): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def infer_shape(self, inputs): |
|
|
|
validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name) |
|
|
|
input_0 = inputs[0] |
|
|
|
|
|
|
|
for i in range(1, len(inputs)): |
|
|
|
if inputs[i] != input_0: |
|
|
|
raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as " |
|
|
|
f"first input {input_0}, but got {inputs[i]}.") |
|
|
|
|
|
|
|
return (inputs[0], [1]) |
|
|
|
|
|
|
|
def infer_dtype(self, inputs): |
|
|
|
|