|
|
|
@@ -77,15 +77,15 @@ class GNNFeatureTransform(nn.Cell): |
|
|
|
self.has_bias = check_bool(has_bias) |
|
|
|
|
|
|
|
if isinstance(weight_init, Tensor): |
|
|
|
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ |
|
|
|
weight_init.shape()[1] != in_channels: |
|
|
|
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ |
|
|
|
weight_init.shape[1] != in_channels: |
|
|
|
raise ValueError("weight_init shape error") |
|
|
|
|
|
|
|
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") |
|
|
|
|
|
|
|
if self.has_bias: |
|
|
|
if isinstance(bias_init, Tensor): |
|
|
|
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: |
|
|
|
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: |
|
|
|
raise ValueError("bias_init shape error") |
|
|
|
|
|
|
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") |
|
|
|
|