|
|
|
@@ -18,6 +18,6 @@ class SbertForSentenceSimilarity(SbertForSequenceClassificationBase): |
|
|
|
model_cls (Optional[Any], optional): model loader, if None, use the |
|
|
|
default loader to load model weights, by default None. |
|
|
|
""" |
|
|
|
super().__init__(model_dir, *args, **kwargs) |
|
|
|
super().__init__(model_dir, *args, model_args={"num_labels": 2}, **kwargs) |
|
|
|
self.model_dir = model_dir |
|
|
|
assert self.model.config.num_labels == 2 |