|
|
|
@@ -775,7 +775,7 @@ class Tril(Cell): |
|
|
|
|
|
|
|
def construct(self, x, k=0): |
|
|
|
assist = tril(x.shape, self.dtype(x), k) |
|
|
|
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32)) |
|
|
|
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32)) |
|
|
|
return self.cast(result, self.dtype(x)) |
|
|
|
|
|
|
|
|
|
|
|
@@ -817,7 +817,7 @@ class Triu(Cell): |
|
|
|
|
|
|
|
def construct(self, x, k=0): |
|
|
|
assist = triu(x.shape, self.dtype(x), k) |
|
|
|
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32)) |
|
|
|
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32)) |
|
|
|
return self.cast(result, self.dtype(x)) |
|
|
|
|
|
|
|
|
|
|
|
|