
fix warning

master
行嗔 committed 3 years ago
commit 202bd4c5a8
20 changed files with 294 additions and 63 deletions
  1. modelscope/models/multi_modal/ofa/generate/search.py (+1, -1)
  2. modelscope/models/multi_modal/ofa/generate/sequence_generator.py (+3, -8)
  3. modelscope/models/multi_modal/ofa/modeling_ofa.py (+2, -1)
  4. modelscope/models/multi_modal/ofa/utils/constant.py (+2, -2)
  5. modelscope/models/multi_modal/ofa_for_all_tasks.py (+16, -3)
  6. modelscope/preprocessors/multi_modal.py (+17, -3)
  7. modelscope/preprocessors/ofa/base.py (+2, -1)
  8. modelscope/preprocessors/ofa/image_captioning.py (+10, -6)
  9. modelscope/preprocessors/ofa/image_classification.py (+8, -6)
  10. modelscope/preprocessors/ofa/summarization.py (+6, -4)
  11. modelscope/preprocessors/ofa/text_classification.py (+7, -4)
  12. modelscope/preprocessors/ofa/text_to_image_synthesis.py (+6, -4)
  13. modelscope/preprocessors/ofa/utils/collate.py (+5, -2)
  14. modelscope/preprocessors/ofa/visual_entailment.py (+9, -6)
  15. modelscope/preprocessors/ofa/visual_grounding.py (+9, -6)
  16. modelscope/preprocessors/ofa/visual_question_answering.py (+8, -6)
  17. modelscope/trainers/multi_modal/ofa/__init__.py (+0, -0)
  18. modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py (+131, -0)
  19. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+0, -0)
  20. modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+52, -0)

modelscope/models/multi_modal/ofa/generate/search.py (+1, -1)

@@ -148,7 +148,7 @@ class BeamSearch(Search):
         scores_buf = top_prediction[0]
         indices_buf = top_prediction[1]
         # Project back into relative indices and beams
-        beams_buf = indices_buf // vocab_size
+        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
         indices_buf = indices_buf.fmod(vocab_size)

         # At this point, beams_buf and indices_buf are single-dim and contain relative indices
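Note: replacing tensor `//` with torch.div(..., rounding_mode='floor') is the fix PyTorch's own __floordiv__ deprecation warning suggests; since these indices are non-negative, floor and the old trunc behavior agree, so only the warning goes away. A minimal sketch with made-up values:

    import torch

    indices_buf = torch.tensor([7, 12, 25])  # flat (beam * vocab) indices
    vocab_size = 10
    # indices_buf // vocab_size would emit the deprecation UserWarning here.
    beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
    indices_buf = indices_buf.fmod(vocab_size)
    print(beams_buf)    # tensor([0, 1, 2])
    print(indices_buf)  # tensor([7, 2, 5])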


modelscope/models/multi_modal/ofa/generate/sequence_generator.py (+3, -8)

@@ -385,12 +385,7 @@ class SequenceGenerator(nn.Module):
                     attn = torch.empty(bsz * beam_size,
                                        avg_attn_scores.size(1),
                                        max_len + 2).to(scores)
-                # print("+++++++ debug attention shape +++++++")
-                # print("attn", attn.shape)
-                # print("avg_attn_scores", avg_attn_scores.shape)
                 attn[:, :, step + 1].copy_(avg_attn_scores)
-                # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape)
-                # print("attn", attn.shape)

             scores = scores.type_as(lprobs)
             eos_bbsz_idx = torch.empty(0).to(
@@ -403,7 +398,8 @@
             if self.should_set_src_lengths:
                 self.search.set_src_lengths(src_lengths)

-            if self.repeat_ngram_blocker is not None:
+            if self.repeat_ngram_blocker is not None and step > prefix_tokens.size(
+                    1):
                 lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz,
                                                    beam_size, step)

@@ -415,7 +411,6 @@
                 tokens[:, :step + 1],
                 original_batch_idxs,
             )
-
             # cand_bbsz_idx contains beam indices for the top candidate
             # hypotheses, with a range of values: [0, bsz*beam_size),
             # and dimensions: [bsz, cand_size]
@@ -671,7 +666,7 @@
             cum_unfin.append(prev)
         cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)

-        unfin_idx = bbsz_idx // beam_size
+        unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor')
        sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)

        # Create a set of "{sent}{unfin_idx}", where
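Note: besides the same floor-division fix as in search.py, this file now skips n-gram blocking while the decoder is still emitting forced prefix tokens, so the blocker cannot zero out a token the prefix requires. The floor division being silenced maps flat (batch * beam) positions back to sentences; a small sketch with made-up values:

    import torch

    beam_size = 5
    bbsz_idx = torch.tensor([0, 4, 7, 12])  # flat batch*beam indices
    # Floor division recovers which sentence each beam hypothesis belongs to.
    unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor')
    print(unfin_idx)  # tensor([0, 0, 1, 2])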


modelscope/models/multi_modal/ofa/modeling_ofa.py (+2, -1)

@@ -114,7 +114,8 @@ def make_image_bucket_position(bucket_size, num_relative_distance):
     """
     coords_h = torch.arange(bucket_size)
     coords_w = torch.arange(bucket_size)
-    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+    coords = torch.stack(torch.meshgrid([coords_h, coords_w],
+                                        indexing='ij'))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - \
         coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
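Note: newer PyTorch warns when torch.meshgrid is called without an indexing argument; indexing='ij' keeps the historical behavior, so the output is unchanged. A quick check:

    import torch

    coords_h = torch.arange(3)
    coords_w = torch.arange(3)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
    print(coords.shape)  # torch.Size([2, 3, 3]), i.e. 2 x Wh x Ww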


modelscope/models/multi_modal/ofa/utils/constant.py (+2, -2)

@@ -7,7 +7,7 @@ OFA_TASK_KEY_MAPPING = {
     Tasks.summarization: OutputKeys.TEXT,
     Tasks.visual_question_answering: OutputKeys.TEXT,
     Tasks.visual_grounding: OutputKeys.BOXES,
-    Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS),
+    Tasks.text_classification: OutputKeys.LABELS,
     Tasks.image_classification: OutputKeys.LABELS,
-    Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS),
+    Tasks.visual_entailment: OutputKeys.LABELS,
 }

modelscope/models/multi_modal/ofa_for_all_tasks.py (+16, -3)

@@ -127,10 +127,23 @@ class OfaForAllTasks(TorchModel):
         return input

     def _text_gen_inference(self, input):
+        import pdb
+        pdb.set_trace()
         input = move_to_device(input, self._device)
-        gen_output = self.generator.generate([self.model], input)
-        gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
-        result = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
+        if 'prefix_tokens' in input:
+            gen_output = self.generator.generate(
+                [self.model], input, prefix_tokens=input['prefix_tokens'])
+        else:
+            gen_output = self.generator.generate([self.model], input)
+        gen_l = list()
+        for i in range(len(gen_output)):
+            if 'prefix_tokens' in input:
+                prefix_tokens = input['prefix_tokens']
+                gen_l.append(
+                    gen_output[i][0]['tokens'][len(prefix_tokens[i]):])
+            else:
+                gen_l.append(gen_output[i][0]['tokens'])
+        result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True)
         # text generation tasks have no score
         ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result}
         if self.cfg.task.endswith('classification'):
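Note: when generation is seeded with prefix_tokens, each returned hypothesis begins with the forced prefix, so the loop above slices it off before batch_decode. A sketch with hypothetical token ids:

    import torch

    prefix = torch.tensor([101, 2054])                 # forced prompt ids
    hypo = torch.tensor([101, 2054, 7592, 2088, 102])  # generator output
    generated = hypo[len(prefix):]                     # drop the forced prefix
    print(generated)  # tensor([7592, 2088, 102])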


modelscope/preprocessors/multi_modal.py (+17, -3)

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
+from io import BytesIO
 from typing import Any, Dict, List, Union

 import torch
@@ -8,6 +9,7 @@ from PIL import Image
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
 from modelscope.pipelines.base import Input
+from modelscope.preprocessors.image import load_image
 from modelscope.utils.config import Config
 from modelscope.utils.constant import Fields, ModelFile, Tasks
 from .base import Preprocessor
@@ -71,20 +73,32 @@ class OfaPreprocessor(Preprocessor):
             data[key] = item
         return data

+    def _compatible_with_pretrain(self, data):
+        if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
+            image = load_image(data['image'])
+            img_buffer = BytesIO()
+            image.save(img_buffer, format='JPEG')
+            data['image'] = Image.open(img_buffer)
+        return data
+
     def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
                  **kwargs) -> Dict[str, Any]:
         if isinstance(input, dict):
             data = input
         else:
             data = self._build_dict(input)
+        data = self._compatible_with_pretrain(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
             str_data[k] = str(v)
         sample['sample'] = str_data
-        return collate_fn([sample],
-                          pad_idx=self.tokenizer.pad_token_id,
-                          eos_idx=self.tokenizer.eos_token_id)
+        if kwargs.get('no_collate', None):
+            return sample
+        else:
+            return collate_fn([sample],
+                              pad_idx=self.tokenizer.pad_token_id,
+                              eos_idx=self.tokenizer.eos_token_id)


 @PREPROCESSORS.register_module(
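Note: _compatible_with_pretrain re-encodes the input image as JPEG in memory, presumably so that 'ofa'-type pretrained checkpoints see the same JPEG compression artifacts at inference that they saw in their training data. The round-trip itself is plain PIL; a self-contained sketch (the file name is hypothetical):

    from io import BytesIO

    from PIL import Image

    image = Image.open('example.png').convert('RGB')  # hypothetical input
    buf = BytesIO()
    image.save(buf, format='JPEG')  # re-encode with JPEG compression
    image = Image.open(buf)         # decode back from the in-memory buffer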


modelscope/preprocessors/ofa/base.py (+2, -1)

@@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed

 class OfaBasePreprocessor:

-    def __init__(self, cfg, model_dir):
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
         """preprocess the data via the vocab.txt from the `model_dir` path

         Args:
@@ -76,6 +76,7 @@
             text,
             max_length=self.max_src_length,
             add_special_tokens=False,
+            truncation=True,
             return_tensors='pt')['input_ids'].squeeze(0)
         if add_bos:
             inputs = torch.cat([self.bos_item, inputs])
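Note: with max_length given but truncation unset, HuggingFace tokenizers fall back to an implicit truncation strategy and log a warning; passing truncation=True makes the clipping explicit. A sketch with an assumed checkpoint:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed
    inputs = tokenizer(
        'a very long input ' * 100,
        max_length=16,
        truncation=True,            # explicit, so no truncation warning
        add_special_tokens=False,
        return_tensors='pt')['input_ids'].squeeze(0)
    print(inputs.shape)  # torch.Size([16])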


modelscope/preprocessors/ofa/image_captioning.py (+10, -6)

@@ -6,24 +6,28 @@ from PIL import Image
 from torchvision import transforms

 from modelscope.preprocessors.image import load_image
+from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor


 class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaImageCaptioningPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])
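Note: torchvision has deprecated passing PIL resampling constants such as Image.BICUBIC to transforms.Resize; transforms.InterpolationMode.BICUBIC is the enum it expects, with identical resampling. A quick check (sizes are made up):

    from PIL import Image
    from torchvision import transforms

    resize = transforms.Resize(
        (224, 224), interpolation=transforms.InterpolationMode.BICUBIC)
    out = resize(Image.new('RGB', (640, 480)))
    print(out.size)  # (224, 224)

The same Resize fix is repeated in the image_classification, visual_entailment, visual_grounding and visual_question_answering preprocessors below.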


modelscope/preprocessors/ofa/image_classification.py (+8, -6)

@@ -11,20 +11,22 @@ from .base import OfaBasePreprocessor

 class OfaImageClassificationPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
         super(OfaImageClassificationPreprocessor,
-              self).__init__(cfg, model_dir)
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])


modelscope/preprocessors/ofa/summarization.py (+6, -4)

@@ -6,14 +6,16 @@ from .base import OfaBasePreprocessor

 class OfaSummarizationPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaSummarizationPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         source = super().pre_caption(


modelscope/preprocessors/ofa/text_classification.py (+7, -4)

@@ -6,14 +6,16 @@ from .base import OfaBasePreprocessor

 class OfaTextClassificationPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaTextClassificationPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         text1 = ' '.join(
@@ -34,5 +36,6 @@ class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
         sample = {
             'source': inputs,
             'decoder_prompt': decoder_prompt,
+            'prefix_token': decoder_prompt[:-1],
         }
         return sample
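Note: reading the diff, the prompt minus its final token becomes the generator's prefix_tokens (via the new collate branch below and the prefix_tokens path in OfaForAllTasks above), presumably so decoding is forced through the prompt and the model continues from its last position. A hypothetical sample as this preprocessor would now emit it:

    import torch

    decoder_prompt = torch.tensor([0, 19, 27, 5])  # made-up prompt ids
    sample = {
        'source': torch.tensor([0, 11, 12, 2]),    # made-up source ids
        'decoder_prompt': decoder_prompt,
        'prefix_token': decoder_prompt[:-1],       # all but the last id
    }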

modelscope/preprocessors/ofa/text_to_image_synthesis.py (+6, -4)

@@ -8,14 +8,16 @@ from .base import OfaBasePreprocessor

 class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
-            model_dir (str): model path
+            cfg(modelscope.utils.config.ConfigDict) : model config
+            model_dir (str): model path,
+            split: data phase
         """
         super(OfaTextToImageSynthesisPreprocessor,
-              self).__init__(cfg, model_dir)
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         self.max_src_length = 64

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:


modelscope/preprocessors/ofa/utils/collate.py (+5, -2)

@@ -50,8 +50,11 @@ def collate_fn(samples, pad_idx, eos_idx):
     if samples[0].get('constraint_mask', None) is not None:
         batch['constraint_masks'] = merge('constraint_mask')
     if samples[0].get('decoder_prompt', None) is not None:
-        batch['decoder_prompts'] = np.array(
-            [s['decoder_prompt'].tolist() for s in samples])
+        batch['decoder_prompts'] = torch.stack(
+            [s['decoder_prompt'] for s in samples], dim=0)
+    if samples[0].get('prefix_token', None) is not None:
+        batch['prefix_tokens'] = torch.stack(
+            [s['prefix_token'] for s in samples], dim=0)
     # For detection and visual grounding
     if samples[0].get('w_resize_ratio', None) is not None:
         batch['w_resize_ratios'] = torch.stack(


modelscope/preprocessors/ofa/visual_entailment.py (+9, -6)

@@ -11,19 +11,22 @@ from .base import OfaBasePreprocessor

 class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaVisualEntailmentPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])


modelscope/preprocessors/ofa/visual_grounding.py (+9, -6)

@@ -11,19 +11,22 @@ from .base import OfaBasePreprocessor

 class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaVisualGroundingPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])


modelscope/preprocessors/ofa/visual_question_answering.py (+8, -6)

@@ -11,20 +11,22 @@ from .base import OfaBasePreprocessor

 class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data

         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
         super(OfaVisualQuestionAnsweringPreprocessor,
-              self).__init__(cfg, model_dir)
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])


modelscope/trainers/multi_modal/ofa/__init__.py (+0, -0)


modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py (+131, -0)

@@ -0,0 +1,131 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import pickle

import torch


class OFAFileDataset:

    def __init__(self,
                 file_path,
                 selected_col_ids=None,
                 dtypes=None,
                 separator='\t',
                 cached_index=False):
        self.file_path = file_path
        assert os.path.exists(
            self.file_path), 'Error: The local datafile {} not exists!'.format(
                self.file_path)

        self.separator = separator
        if selected_col_ids is None:
            # default to all fields
            self.selected_col_ids = list(
                range(
                    len(
                        open(self.file_path).readline().rstrip('\n').split(
                            self.separator))))
        else:
            self.selected_col_ids = [
                int(col_id) for col_id in selected_col_ids.split(',')
            ]
        if dtypes is None:
            # default to str
            self.dtypes = [str for col_id in self.selected_col_ids]
        else:
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
        assert len(self.dtypes) == len(self.selected_col_ids)

        self.data_cnt = 0
        try:
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print('file {} slice_id {} row count {} total row count {}'.format(
            self.file_path, self.slice_id, self.row_count,
            self.total_row_count))

    def _init_seek_index(self):
        if self.cached_index:
            cache_path = '{}.index'.format(self.file_path)
            assert os.path.exists(
                cache_path), 'cache file {} not exists!'.format(cache_path)
            self.total_row_count, self.lineid_to_offset = pickle.load(
                open(cache_path, 'rb'))
            print(
                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            fp = open(self.file_path, 'r')
            print(
                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            for line in fp:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
        self._compute_start_pos_and_row_count()
        print(
            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
            .format(self.file_path, self.slice_id))

    def _compute_start_pos_and_row_count(self):
        self.row_count = self.total_row_count // self.slice_count
        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + (
                self.total_row_count - self.row_count * self.slice_count)

    def _get_reader(self):
        fp = open(self.file_path, 'r')
        fp.seek(self.lineid_to_offset[self.start_pos])
        return fp

    def _seek(self, offset=0):
        try:
            print('slice_id {} seek offset {}'.format(self.slice_id,
                                                      self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        self._reader.close()

    def __len__(self):
        return self.row_count

    def get_total_row_count(self):
        return self.total_row_count

    def __getitem__(self, index):
        if self.data_cnt == self.row_count:
            print('reach the end of datafile, start a new reader')
            self.data_cnt = 0
            self._reader = self._get_reader()
        column_l = self._reader.readline().rstrip('\n').split(self.separator)
        self.data_cnt += 1
        column_l = [
            dtype(column_l[col_id])
            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
        ]
        return column_l
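Note: a minimal usage sketch, assuming a local two-column TSV file; selected_col_ids and dtypes are comma-separated strings, and dtypes is parsed with eval(), so plain type names like 'str' are expected. Reading is sequential: __getitem__ ignores its index and returns the next line of this rank's slice, wrapping around at the end.

    # Hypothetical file: each row is "caption<TAB>image_path".
    dataset = OFAFileDataset(
        'train.tsv',
        selected_col_ids='0,1',
        dtypes='str,str',
        separator='\t',
        cached_index=False)
    print(len(dataset), dataset.get_total_row_count())
    caption, image_path = dataset[0]  # index is ignored; reads sequentially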

modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+0, -0)


modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+52, -0)

@@ -0,0 +1,52 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from os import path as osp

from torch.utils.data import Dataset

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
from .ofa_file_dataset import OFAFileDataset


class OFADataset(Dataset):

    def __init__(self,
                 model_dir,
                 file_path,
                 dtypes=None,
                 separator='\t',
                 cached_index=False,
                 split=ModeKeys.TRAIN,
                 **kwargs):
        self.cfg = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        selected_col_ids = self.cfg.dataset.selected_col_ids
        selected_col_keys = self.cfg.dataset.selected_col_keys

        assert selected_col_ids is not None
        assert selected_col_keys is not None
        self.selected_col_key_l = selected_col_keys.split(',')
        assert len(self.selected_col_key_l) == len(selected_col_ids.split(','))

        self.dataset = OFAFileDataset(
            file_path=file_path,
            selected_col_ids=selected_col_ids,
            dtypes=dtypes,
            separator=separator,
            cached_index=cached_index)
        self.preprocessor = OfaPreprocessor(model_dir, split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        value_l = self.dataset[index]
        data = dict()
        for key, value in zip(self.selected_col_key_l, value_l):
            data[key] = value
        return self.preprocessor(data)
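Note: OFADataset wires an OFAFileDataset to the OfaPreprocessor, mapping each TSV column to the key named at the same position in dataset.selected_col_keys from the model's configuration.json. A usage sketch with hypothetical paths:

    # configuration.json must define dataset.selected_col_ids and
    # dataset.selected_col_keys (e.g. '0,1' and 'text,image').
    dataset = OFADataset(
        model_dir='/path/to/ofa_model',   # hypothetical
        file_path='train.tsv',            # hypothetical
        split=ModeKeys.TRAIN)
    sample = dataset[0]  # a preprocessed, collated single-sample batch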
