|
|
|
@@ -7,11 +7,14 @@ from itertools import chain |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ....utils.logger import get_logger |
|
|
|
from ....utils.nlp.space import ontology, utils |
|
|
|
from ....utils.nlp.space.db_ops import MultiWozDB |
|
|
|
from ....utils.nlp.space.utils import list2np |
|
|
|
from ..tokenizer import Tokenizer |
|
|
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
|
|
|
class BPETextField(object): |
|
|
|
|
|
|
|
@@ -306,7 +309,21 @@ class MultiWOZBPETextField(BPETextField): |
|
|
|
|
|
|
|
def __init__(self, model_dir, config): |
|
|
|
super(MultiWOZBPETextField, self).__init__(config) |
|
|
|
|
|
|
|
import spacy |
|
|
|
try: |
|
|
|
import en_core_web_sm |
|
|
|
except ImportError: |
|
|
|
logger.warn('Miss module en_core_web_sm!') |
|
|
|
logger.warn('We will download en_core_web_sm automatically.') |
|
|
|
try: |
|
|
|
spacy.cli.download('en_core_web_sm') |
|
|
|
except Exception as e: |
|
|
|
logger.error(e) |
|
|
|
raise ImportError( |
|
|
|
'Download en_core_web_sm error. ' |
|
|
|
'Please use \'python -m spacy download en_core_web_sm\' to download it by yourself!' |
|
|
|
) |
|
|
|
self.nlp = spacy.load('en_core_web_sm') |
|
|
|
|
|
|
|
self.db = MultiWozDB( |
|
|
|
|