|
|
|
@@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer): |
|
|
|
def __call__(self, x, name): |
|
|
|
if isinstance(x, Tensor): |
|
|
|
return x |
|
|
|
raise TypeError(f"For {name}, input type should be a Tensor.") |
|
|
|
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") |
|
|
|
|
|
|
|
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): |
|
|
|
""" |
|
|
|
|