diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index bba854f5..d7d3bb8b 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -23,7 +23,8 @@ class AppendToTargetOrInputException(Exception): self.field_name = field_name # 标示当前field的名称 class FieldArray: - def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): + def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False, + use_1st_ins_infer_dim_type=True): if len(content)==0: raise RuntimeError("Empty fieldarray is not allowed.") _content = content @@ -38,6 +39,7 @@ class FieldArray: # 根据input的情况设置input,target等 self._cell_ndim = None # 多少维度 self.dtype = None # 最内层的element都是什么类型的 + self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self._is_input = False self._is_target = False @@ -77,7 +79,7 @@ class FieldArray: if value is True and \ self._is_target is False and \ self._ignore_type is False: - self._check_dtype_and_ndim() + self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) if value is False and self._is_target is False: self.dtype = None self._cell_ndim = None @@ -95,32 +97,34 @@ class FieldArray: if value is True and \ self._is_input is False and \ self._ignore_type is False: - self._check_dtype_and_ndim() + self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) if value is False and self._is_input is False: self.dtype = None self._cell_ndim = None self._is_target = value - def _check_dtype_and_ndim(self): + def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): """ 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 通过将直接报错. + :param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim :return: """ cell_0 = self.content[0] index = 0 try: type_0, dim_0 = _get_ele_type_and_dim(cell_0) - for cell in self.content[1:]: - index += 1 - type_i, dim_i = _get_ele_type_and_dim(cell) - if type_i!=type_0: - raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." - ".".format(type_i, index, type_0)) - if dim_0!=dim_i: - raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " - "dimension:{}.".format(dim_i, index, dim_0)) + if not only_check_1st_ins_dim_type: + for cell in self.content[1:]: + index += 1 + type_i, dim_i = _get_ele_type_and_dim(cell) + if type_i!=type_0: + raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." + ".".format(type_i, index, type_0)) + if dim_0!=dim_i: + raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " + "dimension:{}.".format(dim_i, index, dim_0)) self._cell_ndim = dim_0 self.dtype = type_0 except SetInputOrTargetException as e: @@ -132,7 +136,7 @@ class FieldArray: :param val: 把该val append到fieldarray。 :return: """ - if (self._is_target or self._is_input) and self._ignore_type is False: + if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type: type_, dim_ = _get_ele_type_and_dim(val) if self.dtype!=type_: raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " @@ -144,6 +148,14 @@ class FieldArray: else: self.content.append(val) + def pop(self, index): + """ + 删除该field中index处的元素 + :param int index: 从0开始的数据下标。 + :return: + """ + self.content.pop(index) + def __getitem__(self, indices): return self.get(indices, pad=False)