| @@ -23,7 +23,8 @@ class AppendToTargetOrInputException(Exception): | |||
| self.field_name = field_name # 标示当前field的名称 | |||
| class FieldArray: | |||
| def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): | |||
| def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False, | |||
| use_1st_ins_infer_dim_type=True): | |||
| if len(content)==0: | |||
| raise RuntimeError("Empty fieldarray is not allowed.") | |||
| _content = content | |||
| @@ -38,6 +39,7 @@ class FieldArray: | |||
| # 根据input的情况设置input,target等 | |||
| self._cell_ndim = None # 多少维度 | |||
| self.dtype = None # 最内层的element都是什么类型的 | |||
| self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
| self._is_input = False | |||
| self._is_target = False | |||
| @@ -77,7 +79,7 @@ class FieldArray: | |||
| if value is True and \ | |||
| self._is_target is False and \ | |||
| self._ignore_type is False: | |||
| self._check_dtype_and_ndim() | |||
| self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) | |||
| if value is False and self._is_target is False: | |||
| self.dtype = None | |||
| self._cell_ndim = None | |||
| @@ -95,32 +97,34 @@ class FieldArray: | |||
| if value is True and \ | |||
| self._is_input is False and \ | |||
| self._ignore_type is False: | |||
| self._check_dtype_and_ndim() | |||
| self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) | |||
| if value is False and self._is_input is False: | |||
| self.dtype = None | |||
| self._cell_ndim = None | |||
| self._is_target = value | |||
| def _check_dtype_and_ndim(self): | |||
| def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): | |||
| """ | |||
| 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | |||
| 通过将直接报错. | |||
| :param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim | |||
| :return: | |||
| """ | |||
| cell_0 = self.content[0] | |||
| index = 0 | |||
| try: | |||
| type_0, dim_0 = _get_ele_type_and_dim(cell_0) | |||
| for cell in self.content[1:]: | |||
| index += 1 | |||
| type_i, dim_i = _get_ele_type_and_dim(cell) | |||
| if type_i!=type_0: | |||
| raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||
| ".".format(type_i, index, type_0)) | |||
| if dim_0!=dim_i: | |||
| raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||
| "dimension:{}.".format(dim_i, index, dim_0)) | |||
| if not only_check_1st_ins_dim_type: | |||
| for cell in self.content[1:]: | |||
| index += 1 | |||
| type_i, dim_i = _get_ele_type_and_dim(cell) | |||
| if type_i!=type_0: | |||
| raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||
| ".".format(type_i, index, type_0)) | |||
| if dim_0!=dim_i: | |||
| raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||
| "dimension:{}.".format(dim_i, index, dim_0)) | |||
| self._cell_ndim = dim_0 | |||
| self.dtype = type_0 | |||
| except SetInputOrTargetException as e: | |||
| @@ -132,7 +136,7 @@ class FieldArray: | |||
| :param val: 把该val append到fieldarray。 | |||
| :return: | |||
| """ | |||
| if (self._is_target or self._is_input) and self._ignore_type is False: | |||
| if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type: | |||
| type_, dim_ = _get_ele_type_and_dim(val) | |||
| if self.dtype!=type_: | |||
| raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " | |||
| @@ -144,6 +148,14 @@ class FieldArray: | |||
| else: | |||
| self.content.append(val) | |||
| def pop(self, index): | |||
| """ | |||
| 删除该field中index处的元素 | |||
| :param int index: 从0开始的数据下标。 | |||
| :return: | |||
| """ | |||
| self.content.pop(index) | |||
| def __getitem__(self, indices): | |||
| return self.get(indices, pad=False) | |||