|
|
@@ -210,9 +210,9 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): |
|
|
raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.") |
|
|
raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.") |
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, min_shape, max_shape): |
|
|
def infer_shape(self, x_shape, min_shape, max_shape): |
|
|
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name) |
|
|
self.check_broadcast(min_shape, x_shape) |
|
|
self.check_broadcast(min_shape, x_shape) |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
@@ -273,10 +273,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): |
|
|
raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.") |
|
|
raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.") |
|
|
|
|
|
|
|
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): |
|
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): |
|
|
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) |
|
|
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) |
|
|
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name) |
|
|
self.check_broadcast(min_shape, x_shape) |
|
|
self.check_broadcast(min_shape, x_shape) |
|
|
return x_shape, min_shape, max_shape |
|
|
return x_shape, min_shape, max_shape |
|
|
|
|
|
|
|
|
@@ -325,9 +325,9 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): |
|
|
'narrow_range', narrow_range, (bool,), self.name) |
|
|
'narrow_range', narrow_range, (bool,), self.name) |
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, min_shape, max_shape): |
|
|
def infer_shape(self, x_shape, min_shape, max_shape): |
|
|
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name) |
|
|
validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name) |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
@@ -382,10 +382,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): |
|
|
'narrow_range', narrow_range, (bool,), self.name) |
|
|
'narrow_range', narrow_range, (bool,), self.name) |
|
|
|
|
|
|
|
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): |
|
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): |
|
|
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name) |
|
|
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) |
|
|
validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) |
|
|
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) |
|
|
|
|
|
|
|
|
validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name) |
|
|
validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name) |
|
|
validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name) |
|
|
return x_shape, min_shape, max_shape |
|
|
return x_shape, min_shape, max_shape |
|
|
|
|
|
|
|
|
|