Former-commit-id:masterb0d35d9aea[formerly28b7582d25] [formerly632f8af13a[formerly9eb3583256]] [formerlyab6f5a8a9c[formerly2720e70c78] [formerlyc1f615a3c5[formerly21840b3787]]] [formerly2184b60ad6[formerly2dae6c4ed4] [formerly6a577865b8[formerlyee28aaed84]] [formerly72078d88c8[formerlyab1fb9d512] [formerly476600a47c[formerlyf5a6205bb5]]]] [formerlye5c2c3deef[formerly8585dc41a9] [formerly9cc7fe2088[formerly1d2104316f]] [formerlyb51e614a11[formerlyf9891a191d] [formerly048aa2f114[formerlyd4de64574b]]] [formerlyff66f55a51[formerly0b691b9a7f] [formerlye64e8dd253[formerlycc45939e01]] [formerlycdad10712a[formerly2789e20e79] [formerly25924f293c[formerly32997accab]]]]] [formerlyf43b431040[formerly95815d02ca] [formerlyfe9bd45d44[formerly6daf0aa73e]] [formerly61ab30c9a3[formerlya13c6e23b4] [formerly86fa5919ee[formerly1e49e1a303]]] [formerly69c5bc967a[formerlyab82915cbd] [formerlyf8057c3b14[formerly5232f34578]] [formerly671c54e952[formerly6454a28f26] [formerly3db6ff66b9[formerlyaa8c7fe127]]]] [formerly86b2ec6b84[formerlyf35c344efe] [formerlyd5616f66cd[formerly98c9dca7da]] [formerlya7dcc62bc5[formerly4ef4fa0c98] [formerly55f670b9ae[formerly1cd4421e2e]]] [formerlyd7a5bab832[formerlyc77c5b48df] [formerly01ffd33e2f[formerlyaea728ceb6]] [formerly16afb18e35[formerly4768b156f3] [formerly3c1298c626[formerly1e61cf0974]]]]]] Former-commit-id:28b09a56cb[formerly5241dbb36c] [formerlyd43909d979[formerly43b0cca7f5]] [formerly8d8f384c8e[formerlybcf58203c6] [formerlyca56bff2d0[formerly7a8750ffc9]]] [formerly2ce9fa87ae[formerlye7b1b542f5] [formerly62a7edf94a[formerly26ca5f220d]] [formerly73d10253b9[formerlyaf463cecb0] [formerly961314f474[formerly6e7141e5e2]]]] [formerlyc3f8938ba5[formerly7f66748292] [formerly3e5ef2c136[formerly266ddb4ccc]] [formerly4a8a5437b3[formerlycbcb0f8777] [formerlyc212b8217f[formerly5f3d3d01c8]]] [formerly09764ba6cd[formerly4db991be5f] [formerlyf79c6ec15d[formerly7f8eb54d47]] [formerly771b00b188[formerlyf1dcba565f] [formerly83c42510ad[formerly3c1298c626]]]]] 
Former-commit-id:f5cdcca4f3[formerlydc57f947a2] [formerly55cd6eb9a3[formerlye92d6c0923]] [formerlyba80ed43d2[formerly0a2f65401a] [formerly73c0c2ebb3[formerlybf80cf285e]]] [formerlya34f21d933[formerlyae311cb3e7] [formerlyae0e3ed079[formerlyc9030d0303]] [formerlyf8f4f6a8ec[formerly343dd65df6] [formerly1c97e6a7ba[formerlyc69d281aca]]]] Former-commit-id:b93de57240[formerly6fd32d0759] [formerlyff29e71cb0[formerlyee483b56f9]] [formerly2f944ec28b[formerly15c46e806c] [formerlydca0c18b8b[formerlyf185a8658c]]] Former-commit-id:76ed7f184f[formerlyc8adbe1dea] [formerly392fc3e54b[formerly5bd396738e]] Former-commit-id:11579edbe1[formerly6970795314] Former-commit-id:8fca90698b
| @@ -0,0 +1,297 @@ | |||||
| import os | |||||
| import os.path | |||||
| import hashlib | |||||
| import gzip | |||||
| import errno | |||||
| import tarfile | |||||
| from typing import Any, Callable, List, Iterable, Optional, TypeVar | |||||
| import zipfile | |||||
| # import torch | |||||
| # from torch.utils.model_zoo import tqdm | |||||
| from hub import tqdm | |||||
def gen_bar_updater() -> Callable[[int, int, int], None]:
    """Build a ``urlretrieve`` reporthook that drives a tqdm progress bar."""
    progress_bar = tqdm(total=None)

    def update_progress(block_count, block_size, total_size):
        # Learn the total lazily, once the server reports a content length.
        if progress_bar.total is None and total_size:
            progress_bar.total = total_size
        downloaded = block_count * block_size
        progress_bar.update(downloaded - progress_bar.n)

    return update_progress
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal MD5 digest of the file at *fpath*.

    Reads the file in *chunk_size* blocks so large files never need to
    fit in memory at once.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    """Return True iff the MD5 digest of *fpath* equals *md5*.

    Extra keyword arguments are forwarded to :func:`calculate_md5`.
    """
    actual = calculate_md5(fpath, **kwargs)
    return actual == md5
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    """Return True if *fpath* is an existing file and, when *md5* is
    provided, its checksum matches."""
    if not os.path.isfile(fpath):
        return False
    # Without an expected checksum, mere existence counts as intact.
    return True if md5 is None else check_md5(fpath, md5)
def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: if the file is missing or fails the checksum after download.
    """
    # Fix: `import urllib` alone does not bind the `request`/`error`
    # submodules, so `urllib.request.urlretrieve` below could raise
    # AttributeError. Import the submodules explicitly.
    import urllib.request
    import urllib.error

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # Check if the file is already present locally.
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:  # download the file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        except (urllib.error.URLError, IOError) as e:
            if url.startswith('https'):
                # Retry once over plain HTTP; some mirrors misbehave over TLS.
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater()
                )
            else:
                raise e
    # Check integrity of downloaded file.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    found = []
    for entry in os.listdir(root):
        if not os.path.isdir(os.path.join(root, entry)):
            continue
        found.append(os.path.join(root, entry) if prefix is True else entry)
    return found
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    matches = []
    for entry in os.listdir(root):
        full = os.path.join(root, entry)
        if os.path.isfile(full) and entry.endswith(suffix):
            matches.append(full if prefix is True else entry)
    return matches
| def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined] | |||||
| return "Google Drive - Quota exceeded" in response.text | |||||
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file from and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: if Google Drive reports the file's quota as exceeded.
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        session = requests.Session()

        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            # Large files require confirming Google's "can't scan for
            # viruses" interstitial via this cookie token.
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        if _quota_exceeded(response):
            # Fix: the f-string previously hard-coded "(unknown)" with no
            # interpolation; report the actual file id so the error is
            # actionable.
            msg = (
                f"The daily quota of the file {file_id} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )
            raise RuntimeError(msg)

        _save_response_content(response, fpath)
| def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] | |||||
| for key, value in response.cookies.items(): | |||||
| if key.startswith('download_warning'): | |||||
| return value | |||||
| return None | |||||
def _save_response_content(
    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
) -> None:
    """Stream the body of *response* to *destination*, showing progress."""
    with open(destination, "wb") as out:
        pbar = tqdm(total=None)
        written = 0
        for chunk in response.iter_content(chunk_size):
            # Empty chunks are keep-alives; skip them.
            if not chunk:
                continue
            out.write(chunk)
            written += len(chunk)
            pbar.update(written - pbar.n)
        pbar.close()
| def _is_tarxz(filename: str) -> bool: | |||||
| return filename.endswith(".tar.xz") | |||||
| def _is_tar(filename: str) -> bool: | |||||
| return filename.endswith(".tar") | |||||
| def _is_targz(filename: str) -> bool: | |||||
| return filename.endswith(".tar.gz") | |||||
| def _is_tgz(filename: str) -> bool: | |||||
| return filename.endswith(".tgz") | |||||
| def _is_gzip(filename: str) -> bool: | |||||
| return filename.endswith(".gz") and not filename.endswith(".tar.gz") | |||||
| def _is_zip(filename: str) -> bool: | |||||
| return filename.endswith(".zip") | |||||
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
    """Extract a tar/tgz/txz/gz/zip archive.

    Args:
        from_path (str): path to the archive.
        to_path (str, optional): destination directory; defaults to the
            directory containing the archive.
        remove_finished (bool, optional): delete the archive afterwards.

    Raises:
        ValueError: if the archive extension is not recognized.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # NOTE(review): extractall() trusts member paths, so a malicious archive
    # could write outside to_path (path traversal) — acceptable only for
    # trusted dataset sources; confirm before exposing to untrusted input.
    tar_mode = None
    if _is_tar(from_path):
        tar_mode = 'r'
    elif _is_targz(from_path) or _is_tgz(from_path):
        tar_mode = 'r:gz'
    elif _is_tarxz(from_path):
        tar_mode = 'r:xz'

    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as archive:
            archive.extractall(path=to_path)
    elif _is_gzip(from_path):
        # Decompress <name>.gz into <to_path>/<name>.
        target = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(target, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as archive:
            archive.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download *url* into *download_root* and report the extraction target.

    Args:
        url (str): URL to download.
        download_root (str): directory that receives the download.
        extract_root (str, optional): extraction target; defaults to download_root.
        filename (str, optional): local file name; defaults to the URL basename.
        md5 (str, optional): expected MD5 checksum; skipped when None.
        remove_finished (bool, optional): would delete the archive after
            extraction (currently unused — extraction is disabled below).
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename if filename else os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    # NOTE(review): the actual extraction call is intentionally commented out
    # upstream; this function currently only downloads and prints.
    # extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable: Iterable) -> str:
    """Render *iterable* as a single-quoted, comma-separated list.

    An empty iterable renders as ``''`` (a pair of quotes), matching the
    surrounding-quote construction.
    """
    parts = [str(item) for item in iterable]
    return "'{}'".format("', '".join(parts))
T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    """Validate that *value* is a string drawn from *valid_values*.

    Args:
        value: the value to check.
        arg: name of the argument, used only in error messages.
        valid_values: allowed values; any value is accepted when None.
        custom_msg: replaces the default "unknown value" error message.

    Returns:
        *value*, unchanged.

    Raises:
        ValueError: if *value* is not a str/bytes, or not in *valid_values*.
    """
    # Fix: the original referenced torch._six.string_classes, but the
    # `import torch` at the top of this file is commented out, so that line
    # raised NameError. torch._six.string_classes is (str, bytes) on Py3.
    if not isinstance(value, (str, bytes)):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
| @@ -0,0 +1,559 @@ | |||||
| import errno | |||||
| import hashlib | |||||
| import os | |||||
| import re | |||||
| import shutil | |||||
| import sys | |||||
| import tempfile | |||||
| # import torch | |||||
| import warnings | |||||
| import zipfile | |||||
| from urllib.request import urlopen, Request | |||||
| from urllib.parse import urlparse # noqa: F401 | |||||
# Resolve a tqdm implementation: prefer tqdm.auto, fall back to plain tqdm,
# and finally to a minimal stderr-based stand-in so progress reporting never
# becomes a hard dependency.
try:
    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
except ImportError:
    try:
        from tqdm import tqdm
    except ImportError:
        # fake tqdm if it's not installed
        class tqdm(object):  # type: ignore
            # Minimal subset of the real tqdm API used by this module:
            # total/n attributes, update(), and context-manager protocol.

            def __init__(self, total=None, disable=False,
                         unit=None, unit_scale=None, unit_divisor=None):
                # total: expected final count (None = unknown).
                # disable: suppress all output when True.
                self.total = total
                self.disable = disable
                self.n = 0
                # ignore unit, unit_scale, unit_divisor; they're just for real tqdm

            def update(self, n):
                # Advance the counter by n and redraw the status line.
                if self.disable:
                    return

                self.n += n
                if self.total is None:
                    # Unknown total: show a raw byte count.
                    sys.stderr.write("\r{0:.1f} bytes".format(self.n))
                else:
                    # Known total: show percentage complete.
                    sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
                sys.stderr.flush()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Finish the in-place status line with a newline.
                if self.disable:
                    return

                sys.stderr.write('\n')
| # # matches bfd8deac from resnet18-bfd8deac.pth | |||||
| # HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||||
| # | |||||
| # MASTER_BRANCH = 'master' | |||||
| # ENV_TORCH_HOME = 'TORCH_HOME' | |||||
| # ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | |||||
| # DEFAULT_CACHE_DIR = '~/.cache' | |||||
| # VAR_DEPENDENCY = 'dependencies' | |||||
| # MODULE_HUBCONF = 'hubconf.py' | |||||
| # READ_DATA_CHUNK = 8192 | |||||
| # _hub_dir = None | |||||
| # | |||||
| # | |||||
| # # Copied from tools/shared/module_loader to be included in torch package | |||||
| # def import_module(name, path): | |||||
| # import importlib.util | |||||
| # from importlib.abc import Loader | |||||
| # spec = importlib.util.spec_from_file_location(name, path) | |||||
| # module = importlib.util.module_from_spec(spec) | |||||
| # assert isinstance(spec.loader, Loader) | |||||
| # spec.loader.exec_module(module) | |||||
| # return module | |||||
| # | |||||
| # | |||||
| # def _remove_if_exists(path): | |||||
| # if os.path.exists(path): | |||||
| # if os.path.isfile(path): | |||||
| # os.remove(path) | |||||
| # else: | |||||
| # shutil.rmtree(path) | |||||
| # | |||||
| # | |||||
| # def _git_archive_link(repo_owner, repo_name, branch): | |||||
| # return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch) | |||||
| # | |||||
| # | |||||
| # def _load_attr_from_module(module, func_name): | |||||
| # # Check if callable is defined in the module | |||||
| # if func_name not in dir(module): | |||||
| # return None | |||||
| # return getattr(module, func_name) | |||||
| # | |||||
| # | |||||
| # def _get_torch_home(): | |||||
| # torch_home = os.path.expanduser( | |||||
| # os.getenv(ENV_TORCH_HOME, | |||||
| # os.path.join(os.getenv(ENV_XDG_CACHE_HOME, | |||||
| # DEFAULT_CACHE_DIR), 'torch'))) | |||||
| # return torch_home | |||||
| # | |||||
| # | |||||
| # def _parse_repo_info(github): | |||||
| # branch = MASTER_BRANCH | |||||
| # if ':' in github: | |||||
| # repo_info, branch = github.split(':') | |||||
| # else: | |||||
| # repo_info = github | |||||
| # repo_owner, repo_name = repo_info.split('/') | |||||
| # return repo_owner, repo_name, branch | |||||
| # | |||||
| # | |||||
| # def _get_cache_or_reload(github, force_reload, verbose=True): | |||||
| # # Setup hub_dir to save downloaded files | |||||
| # hub_dir = get_dir() | |||||
| # if not os.path.exists(hub_dir): | |||||
| # os.makedirs(hub_dir) | |||||
| # # Parse github repo information | |||||
| # repo_owner, repo_name, branch = _parse_repo_info(github) | |||||
| # # Github allows branch name with slash '/', | |||||
| # # this causes confusion with path on both Linux and Windows. | |||||
| # # Backslash is not allowed in Github branch name so no need to | |||||
| # # to worry about it. | |||||
| # normalized_br = branch.replace('/', '_') | |||||
| # # Github renames folder repo-v1.x.x to repo-1.x.x | |||||
| # # We don't know the repo name before downloading the zip file | |||||
| # # and inspect name from it. | |||||
| # # To check if cached repo exists, we need to normalize folder names. | |||||
| # repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br])) | |||||
| # | |||||
| # use_cache = (not force_reload) and os.path.exists(repo_dir) | |||||
| # | |||||
| # if use_cache: | |||||
| # if verbose: | |||||
| # sys.stderr.write('Using cache found in {}\n'.format(repo_dir)) | |||||
| # else: | |||||
| # cached_file = os.path.join(hub_dir, normalized_br + '.zip') | |||||
| # _remove_if_exists(cached_file) | |||||
| # | |||||
| # url = _git_archive_link(repo_owner, repo_name, branch) | |||||
| # sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file)) | |||||
| # download_url_to_file(url, cached_file, progress=False) | |||||
| # | |||||
| # with zipfile.ZipFile(cached_file) as cached_zipfile: | |||||
| # extraced_repo_name = cached_zipfile.infolist()[0].filename | |||||
| # extracted_repo = os.path.join(hub_dir, extraced_repo_name) | |||||
| # _remove_if_exists(extracted_repo) | |||||
| # # Unzip the code and rename the base folder | |||||
| # cached_zipfile.extractall(hub_dir) | |||||
| # | |||||
| # _remove_if_exists(cached_file) | |||||
| # _remove_if_exists(repo_dir) | |||||
| # shutil.move(extracted_repo, repo_dir) # rename the repo | |||||
| # | |||||
| # return repo_dir | |||||
| # | |||||
| # | |||||
| # def _check_module_exists(name): | |||||
| # if sys.version_info >= (3, 4): | |||||
| # import importlib.util | |||||
| # return importlib.util.find_spec(name) is not None | |||||
| # elif sys.version_info >= (3, 3): | |||||
| # # Special case for python3.3 | |||||
| # import importlib.find_loader | |||||
| # return importlib.find_loader(name) is not None | |||||
| # else: | |||||
| # # NB: Python2.7 imp.find_module() doesn't respect PEP 302, | |||||
| # # it cannot find a package installed as .egg(zip) file. | |||||
| # # Here we use workaround from: | |||||
| # # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1 | |||||
| # # Also imp doesn't handle hierarchical module names (names contains dots). | |||||
| # try: | |||||
| # # 1. Try imp.find_module(), which searches sys.path, but does | |||||
| # # not respect PEP 302 import hooks. | |||||
| # import imp | |||||
| # result = imp.find_module(name) | |||||
| # if result: | |||||
| # return True | |||||
| # except ImportError: | |||||
| # pass | |||||
| # path = sys.path | |||||
| # for item in path: | |||||
| # # 2. Scan path for import hooks. sys.path_importer_cache maps | |||||
| # # path items to optional "importer" objects, that implement | |||||
| # # find_module() etc. Note that path must be a subset of | |||||
| # # sys.path for this to work. | |||||
| # importer = sys.path_importer_cache.get(item) | |||||
| # if importer: | |||||
| # try: | |||||
| # result = importer.find_module(name, [item]) | |||||
| # if result: | |||||
| # return True | |||||
| # except ImportError: | |||||
| # pass | |||||
| # return False | |||||
| # | |||||
| # def _check_dependencies(m): | |||||
| # dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) | |||||
| # | |||||
| # if dependencies is not None: | |||||
| # missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] | |||||
| # if len(missing_deps): | |||||
| # raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) | |||||
| # | |||||
| # | |||||
| # def _load_entry_from_hubconf(m, model): | |||||
| # if not isinstance(model, str): | |||||
| # raise ValueError('Invalid input: model should be a string of function name') | |||||
| # | |||||
| # # Note that if a missing dependency is imported at top level of hubconf, it will | |||||
| # # throw before this function. It's a chicken and egg situation where we have to | |||||
| # # load hubconf to know what're the dependencies, but to import hubconf it requires | |||||
| # # a missing package. This is fine, Python will throw proper error message for users. | |||||
| # _check_dependencies(m) | |||||
| # | |||||
| # func = _load_attr_from_module(m, model) | |||||
| # | |||||
| # if func is None or not callable(func): | |||||
| # raise RuntimeError('Cannot find callable {} in hubconf'.format(model)) | |||||
| # | |||||
| # return func | |||||
| # | |||||
| # | |||||
| # def get_dir(): | |||||
| # r""" | |||||
| # Get the Torch Hub cache directory used for storing downloaded models & weights. | |||||
| # | |||||
| # If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where | |||||
| # environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. | |||||
| # ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux | |||||
| # filesystem layout, with a default value ``~/.cache`` if the environment | |||||
| # variable is not set. | |||||
| # """ | |||||
| # # Issue warning to move data if old env is set | |||||
| # if os.getenv('TORCH_HUB'): | |||||
| # warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead') | |||||
| # | |||||
| # if _hub_dir is not None: | |||||
| # return _hub_dir | |||||
| # return os.path.join(_get_torch_home(), 'hub') | |||||
| # | |||||
| # | |||||
| # def set_dir(d): | |||||
| # r""" | |||||
| # Optionally set the Torch Hub directory used to save downloaded models & weights. | |||||
| # | |||||
| # Args: | |||||
| # d (string): path to a local folder to save downloaded models & weights. | |||||
| # """ | |||||
| # global _hub_dir | |||||
| # _hub_dir = d | |||||
| # | |||||
| # | |||||
| # def list(github, force_reload=False): | |||||
| # r""" | |||||
| # List all entrypoints available in `github` hubconf. | |||||
| # | |||||
| # Args: | |||||
| # github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional | |||||
| # tag/branch. The default branch is `master` if not specified. | |||||
| # Example: 'pytorch/vision[:hub]' | |||||
| # force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||||
| # Default is `False`. | |||||
| # Returns: | |||||
| # entrypoints: a list of available entrypoint names | |||||
| # | |||||
| # Example: | |||||
| # >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) | |||||
| # """ | |||||
| # repo_dir = _get_cache_or_reload(github, force_reload, True) | |||||
| # | |||||
| # sys.path.insert(0, repo_dir) | |||||
| # | |||||
| # hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||||
| # | |||||
| # sys.path.remove(repo_dir) | |||||
| # | |||||
| # # We take functions starts with '_' as internal helper functions | |||||
| # entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')] | |||||
| # | |||||
| # return entrypoints | |||||
| # | |||||
| # | |||||
| # def help(github, model, force_reload=False): | |||||
| # r""" | |||||
| # Show the docstring of entrypoint `model`. | |||||
| # | |||||
| # Args: | |||||
| # github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional | |||||
| # tag/branch. The default branch is `master` if not specified. | |||||
| # Example: 'pytorch/vision[:hub]' | |||||
| # model (string): a string of entrypoint name defined in repo's hubconf.py | |||||
| # force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||||
| # Default is `False`. | |||||
| # Example: | |||||
| # >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) | |||||
| # """ | |||||
| # repo_dir = _get_cache_or_reload(github, force_reload, True) | |||||
| # | |||||
| # sys.path.insert(0, repo_dir) | |||||
| # | |||||
| # hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||||
| # | |||||
| # sys.path.remove(repo_dir) | |||||
| # | |||||
| # entry = _load_entry_from_hubconf(hub_module, model) | |||||
| # | |||||
| # return entry.__doc__ | |||||
| # | |||||
| # | |||||
| # # Ideally this should be `def load(github, model, *args, force_reload=False, **kwargs):`, | |||||
| # # but Python2 complains syntax error for it. We have to skip force_reload in function | |||||
| # # signature here but detect it in kwargs instead. | |||||
| # # TODO: fix it after Python2 EOL | |||||
| # def load(repo_or_dir, model, *args, **kwargs): | |||||
| # r""" | |||||
| # Load a model from a github repo or a local directory. | |||||
| # | |||||
| # Note: Loading a model is the typical use case, but this can also be used to | |||||
| # for loading other objects such as tokenizers, loss functions, etc. | |||||
| # | |||||
| # If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be | |||||
| # of the form ``repo_owner/repo_name[:tag_name]`` with an optional | |||||
| # tag/branch. | |||||
| # | |||||
| # If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a | |||||
| # path to a local directory. | |||||
| # | |||||
| # Args: | |||||
| # repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``), | |||||
| # if ``source = 'github'``; or a path to a local directory, if | |||||
| # ``source = 'local'``. | |||||
| # model (string): the name of a callable (entrypoint) defined in the | |||||
| # repo/dir's ``hubconf.py``. | |||||
| # *args (optional): the corresponding args for callable :attr:`model`. | |||||
| # source (string, optional): ``'github'`` | ``'local'``. Specifies how | |||||
| # ``repo_or_dir`` is to be interpreted. Default is ``'github'``. | |||||
| # force_reload (bool, optional): whether to force a fresh download of | |||||
| # the github repo unconditionally. Does not have any effect if | |||||
| # ``source = 'local'``. Default is ``False``. | |||||
| # verbose (bool, optional): If ``False``, mute messages about hitting | |||||
| # local caches. Note that the message about first download cannot be | |||||
| # muted. Does not have any effect if ``source = 'local'``. | |||||
| # Default is ``True``. | |||||
| # **kwargs (optional): the corresponding kwargs for callable | |||||
| # :attr:`model`. | |||||
| # | |||||
| # Returns: | |||||
| # The output of the :attr:`model` callable when called with the given | |||||
| # ``*args`` and ``**kwargs``. | |||||
| # | |||||
| # Example: | |||||
| # >>> # from a github repo | |||||
| # >>> repo = 'pytorch/vision' | |||||
| # >>> model = torch.hub.load(repo, 'resnet50', pretrained=True) | |||||
| # >>> # from a local directory | |||||
| # >>> path = '/some/local/path/pytorch/vision' | |||||
| # >>> model = torch.hub.load(path, 'resnet50', pretrained=True) | |||||
| # """ | |||||
| # source = kwargs.pop('source', 'github').lower() | |||||
| # force_reload = kwargs.pop('force_reload', False) | |||||
| # verbose = kwargs.pop('verbose', True) | |||||
| # | |||||
| # if source not in ('github', 'local'): | |||||
| # raise ValueError( | |||||
| # f'Unknown source: "{source}". Allowed values: "github" | "local".') | |||||
| # | |||||
| # if source == 'github': | |||||
| # repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose) | |||||
| # | |||||
| # model = _load_local(repo_or_dir, model, *args, **kwargs) | |||||
| # return model | |||||
| # | |||||
| # | |||||
| # def _load_local(hubconf_dir, model, *args, **kwargs): | |||||
| # r""" | |||||
| # Load a model from a local directory with a ``hubconf.py``. | |||||
| # | |||||
| # Args: | |||||
| # hubconf_dir (string): path to a local directory that contains a | |||||
| # ``hubconf.py``. | |||||
| # model (string): name of an entrypoint defined in the directory's | |||||
| # `hubconf.py`. | |||||
| # *args (optional): the corresponding args for callable ``model``. | |||||
| # **kwargs (optional): the corresponding kwargs for callable ``model``. | |||||
| # | |||||
| # Returns: | |||||
| # a single model with corresponding pretrained weights. | |||||
| # | |||||
| # Example: | |||||
| # >>> path = '/some/local/path/pytorch/vision' | |||||
| # >>> model = _load_local(path, 'resnet50', pretrained=True) | |||||
| # """ | |||||
| # sys.path.insert(0, hubconf_dir) | |||||
| # | |||||
| # hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) | |||||
| # hub_module = import_module(MODULE_HUBCONF, hubconf_path) | |||||
| # | |||||
| # entry = _load_entry_from_hubconf(hub_module, model) | |||||
| # model = entry(*args, **kwargs) | |||||
| # | |||||
| # sys.path.remove(hubconf_dir) | |||||
| # | |||||
| # return model | |||||
| # | |||||
| # | |||||
| # def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||||
| # r"""Download object at the given URL to a local path. | |||||
| # | |||||
| # Args: | |||||
| # url (string): URL of the object to download | |||||
| # dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` | |||||
| # hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. | |||||
| # Default: None | |||||
| # progress (bool, optional): whether or not to display a progress bar to stderr | |||||
| # Default: True | |||||
| # | |||||
| # Example: | |||||
| # >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') | |||||
| # | |||||
| # """ | |||||
| # file_size = None | |||||
| # # We use a different API for python2 since urllib(2) doesn't recognize the CA | |||||
| # # certificates in older Python | |||||
| # req = Request(url, headers={"User-Agent": "torch.hub"}) | |||||
| # u = urlopen(req) | |||||
| # meta = u.info() | |||||
| # if hasattr(meta, 'getheaders'): | |||||
| # content_length = meta.getheaders("Content-Length") | |||||
| # else: | |||||
| # content_length = meta.get_all("Content-Length") | |||||
| # if content_length is not None and len(content_length) > 0: | |||||
| # file_size = int(content_length[0]) | |||||
| # | |||||
| # # We deliberately save it in a temp file and move it after | |||||
| # # download is complete. This prevents a local working checkpoint | |||||
| # # being overridden by a broken download. | |||||
| # dst = os.path.expanduser(dst) | |||||
| # dst_dir = os.path.dirname(dst) | |||||
| # f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | |||||
| # | |||||
| # try: | |||||
| # if hash_prefix is not None: | |||||
| # sha256 = hashlib.sha256() | |||||
| # with tqdm(total=file_size, disable=not progress, | |||||
| # unit='B', unit_scale=True, unit_divisor=1024) as pbar: | |||||
| # while True: | |||||
| # buffer = u.read(8192) | |||||
| # if len(buffer) == 0: | |||||
| # break | |||||
| # f.write(buffer) | |||||
| # if hash_prefix is not None: | |||||
| # sha256.update(buffer) | |||||
| # pbar.update(len(buffer)) | |||||
| # | |||||
| # f.close() | |||||
| # if hash_prefix is not None: | |||||
| # digest = sha256.hexdigest() | |||||
| # if digest[:len(hash_prefix)] != hash_prefix: | |||||
| # raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||||
| # .format(hash_prefix, digest)) | |||||
| # shutil.move(f.name, dst) | |||||
| # finally: | |||||
| # f.close() | |||||
| # if os.path.exists(f.name): | |||||
| # os.remove(f.name) | |||||
| # | |||||
| # def _download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||||
| # warnings.warn('torch.hub._download_url_to_file has been renamed to\ | |||||
| # torch.hub.download_url_to_file to be a public API,\ | |||||
| # _download_url_to_file will be removed in after 1.3 release') | |||||
| # download_url_to_file(url, dst, hash_prefix, progress) | |||||
| # | |||||
| # # Hub used to support automatically extracts from zipfile manually compressed by users. | |||||
| # # The legacy zip format expects only one file from torch.save() < 1.6 in the zip. | |||||
| # # We should remove this support since zipfile is now default zipfile format for torch.save(). | |||||
| # def _is_legacy_zip_format(filename): | |||||
| # if zipfile.is_zipfile(filename): | |||||
| # infolist = zipfile.ZipFile(filename).infolist() | |||||
| # return len(infolist) == 1 and not infolist[0].is_dir() | |||||
| # return False | |||||
| # | |||||
| # def _legacy_zip_load(filename, model_dir, map_location): | |||||
| # warnings.warn('Falling back to the old format < 1.6. This support will be ' | |||||
| # 'deprecated in favor of default zipfile format introduced in 1.6. ' | |||||
| # 'Please redo torch.save() to save it in the new zipfile format.') | |||||
| # # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. | |||||
| # # We deliberately don't handle tarfile here since our legacy serialization format was in tar. | |||||
| # # E.g. resnet18-5c106cde.pth which is widely used. | |||||
| # with zipfile.ZipFile(filename) as f: | |||||
| # members = f.infolist() | |||||
| # if len(members) != 1: | |||||
| # raise RuntimeError('Only one file(not dir) is allowed in the zipfile') | |||||
| # f.extractall(model_dir) | |||||
| # extraced_name = members[0].filename | |||||
| # extracted_file = os.path.join(model_dir, extraced_name) | |||||
| # return torch.load(extracted_file, map_location=map_location) | |||||
| # | |||||
| # def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): | |||||
| # r"""Loads the Torch serialized object at the given URL. | |||||
| # | |||||
| # If downloaded file is a zip file, it will be automatically | |||||
| # decompressed. | |||||
| # | |||||
| # If the object is already present in `model_dir`, it's deserialized and | |||||
| # returned. | |||||
| # The default value of `model_dir` is ``<hub_dir>/checkpoints`` where | |||||
| # `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. | |||||
| # | |||||
| # Args: | |||||
| # url (string): URL of the object to download | |||||
| # model_dir (string, optional): directory in which to save the object | |||||
| # map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||||
| # progress (bool, optional): whether or not to display a progress bar to stderr. | |||||
| # Default: True | |||||
| # check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention | |||||
| # ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||||
| # digits of the SHA256 hash of the contents of the file. The hash is used to | |||||
| # ensure unique names and to verify the contents of the file. | |||||
| # Default: False | |||||
| # file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. | |||||
| # | |||||
| # Example: | |||||
| # >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||||
| # | |||||
| # """ | |||||
| # # Issue warning to move data if old env is set | |||||
| # if os.getenv('TORCH_MODEL_ZOO'): | |||||
| # warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') | |||||
| # | |||||
| # if model_dir is None: | |||||
| # hub_dir = get_dir() | |||||
| # model_dir = os.path.join(hub_dir, 'checkpoints') | |||||
| # | |||||
| # try: | |||||
| # os.makedirs(model_dir) | |||||
| # except OSError as e: | |||||
| # if e.errno == errno.EEXIST: | |||||
| # # Directory already exists, ignore. | |||||
| # pass | |||||
| # else: | |||||
| # # Unexpected OSError, re-raise. | |||||
| # raise | |||||
| # | |||||
| # parts = urlparse(url) | |||||
| # filename = os.path.basename(parts.path) | |||||
| # if file_name is not None: | |||||
| # filename = file_name | |||||
| # cached_file = os.path.join(model_dir, filename) | |||||
| # if not os.path.exists(cached_file): | |||||
| # sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||||
| # hash_prefix = None | |||||
| # if check_hash: | |||||
| # r = HASH_REGEX.search(filename) # r is Optional[Match[str]] | |||||
| # hash_prefix = r.group(1) if r else None | |||||
| # download_url_to_file(url, cached_file, hash_prefix, progress=progress) | |||||
| # | |||||
| # if _is_legacy_zip_format(cached_file): | |||||
| # return _legacy_zip_load(cached_file, model_dir, map_location) | |||||
| # return torch.load(cached_file, map_location=map_location) | |||||
| @@ -0,0 +1,139 @@ | |||||
| import warnings | |||||
| import os | |||||
| import os.path | |||||
| import numpy as np | |||||
| import codecs | |||||
| import string | |||||
| import gzip | |||||
| import lzma | |||||
| from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union | |||||
| from dataset_utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg | |||||
| # tqdm >= 4.31.1 | |||||
| from tods import generate_dataset | |||||
| from sklearn import preprocessing | |||||
| import pandas as pd | |||||
class TODS_dataset:
    """Base class for TODS time-series anomaly-detection datasets.

    Subclasses supply ``resources`` (a list of ``(url, md5)`` pairs),
    ``training_file``/``testing_file`` names, ``ground_truth_index`` and an
    implementation of :meth:`process` that populates
    ``training_set_dataframe`` and ``testing_set_dataframe``.
    """

    resources = []            # (url, md5) pairs downloaded into raw_folder; set by subclass
    training_file = None      # processed training CSV name; set by subclass
    testing_file = None       # processed testing CSV name; set by subclass
    ground_truth_index = None # label column index in the processed dataframe; set by subclass
    _repr_indent = None       # indent width used by __repr__; set by subclass (e.g. 4)

    @property
    def raw_folder(self) -> str:
        """Directory holding the raw downloaded files."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        """Directory holding the processed files."""
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    def __init__(self, root, train, transform=None, download=True):
        """
        Args:
            root (str): root directory under which the dataset is stored.
            train (bool): if True expose the training split, else the testing split.
            transform (str, optional): one of 'standardscale', 'normalize',
                'minmaxscale', 'maxabsscale' or 'binarize'; None disables it.
            download (bool): if True, fetch the raw files when not yet processed.
        """
        self.root = root
        self.train = train
        self.transform = self.transform_init(transform)
        if download:
            self.download()
        self.process()

    def _check_exists(self) -> bool:
        """Return True when both processed split files are already present."""
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.testing_file)))

    def download(self) -> None:
        """Download (and extract) all raw resources, unless already processed."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        for url, md5 in self.resources:
            # Use the final URL path component as the local file name.
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder,
                                         filename=filename, md5=md5)

    def process(self) -> None:
        """Build the processed dataframes; implemented by subclasses."""
        pass

    def process_dataframe(self) -> None:
        """Fit the configured transform on the training split and apply it to both splits."""
        if self.transform is not None:
            self.transform.fit(self.training_set_dataframe)
            self.training_set_array = self.transform.transform(self.training_set_dataframe.values)
            self.testing_set_array = self.transform.transform(self.testing_set_dataframe.values)
            self.training_set_dataframe = pd.DataFrame(self.training_set_array)
            self.testing_set_dataframe = pd.DataFrame(self.testing_set_array)

    def transform_init(self, transform_str):
        """Translate a transform name into an unfitted sklearn transformer.

        Returns None when ``transform_str`` is None; raises ValueError for an
        unknown name.
        """
        if transform_str is None:
            return None
        # Dispatch table instead of an elif chain; built lazily so sklearn is
        # only touched when a transform is actually requested.
        factories = {
            'standardscale': preprocessing.StandardScaler,
            'normalize': preprocessing.Normalizer,
            'minmaxscale': preprocessing.MinMaxScaler,
            'maxabsscale': preprocessing.MaxAbsScaler,
            'binarize': preprocessing.Binarizer,
        }
        try:
            return factories[transform_str]()
        except KeyError:
            raise ValueError("Input parameter transform must take value of 'standardscale', 'normalize', " +
                             "'minmaxscale', 'maxabsscale' or 'binarize'."
                             ) from None

    def to_axolotl_dataset(self):
        """Wrap the selected split (per ``self.train``) into an axolotl dataset."""
        dataframe = self.training_set_dataframe if self.train else self.testing_set_dataframe
        return generate_dataset(dataframe, self.ground_truth_index)

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        # NOTE: a leftover debug print() of the training dataframe was removed
        # here; __repr__ must not have side effects.
        return '\n'.join(lines)

    def __len__(self) -> int:
        # Length is defined by the training split.
        return len(self.training_set_dataframe)

    def extra_repr(self) -> str:
        """Extra lines for __repr__; subclasses may override."""
        return ""
| # kpi(root='./datasets', train=True) | |||||
| # class yahoo5: | |||||
| # | |||||
| # def __init__(self): | |||||
| # pass | |||||
| @@ -0,0 +1,116 @@ | |||||
| import os | |||||
| import pandas as pd | |||||
| from tods_dataset_base import TODS_dataset | |||||
| from shutil import copyfile | |||||
class kpi_dataset(TODS_dataset):
    """KPI anomaly-detection dataset, hosted at hegsns.github.io."""

    resources = [
        # md5 checksums are not published yet, so the second tuple element
        # stays None and no integrity check is performed on download.
        ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
        ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/datasetDoc.json", None),
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 3
    _repr_indent = 4

    def process(self) -> None:
        """Read the raw CSV, apply the transform, and write the processed files."""
        print('Processing...')

        tables_dir = os.path.join(self.processed_folder, 'tables')
        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(tables_dir, exist_ok=True)

        raw_csv = os.path.join(self.raw_folder, 'learningData.csv')
        self.training_set_dataframe = pd.read_csv(raw_csv)
        # The testing split is temporarily identical to the training split.
        self.testing_set_dataframe = pd.read_csv(raw_csv)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(tables_dir, self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(tables_dir, self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')
class yahoo_dataset(TODS_dataset):
    """Yahoo (yahoo_sub_5) anomaly-detection dataset, hosted at hegsns.github.io."""

    resources = [
        # md5 checksums are not published yet, so the second tuple element
        # stays None and no integrity check is performed on download.
        ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
        ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/datasetDoc.json", None),
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 7
    _repr_indent = 4

    def process(self) -> None:
        """Read the raw CSV, apply the transform, and write the processed files."""
        print('Processing...')

        tables_dir = os.path.join(self.processed_folder, 'tables')
        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(tables_dir, exist_ok=True)

        raw_csv = os.path.join(self.raw_folder, 'learningData.csv')
        self.training_set_dataframe = pd.read_csv(raw_csv)
        # The testing split is temporarily identical to the training split.
        self.testing_set_dataframe = pd.read_csv(raw_csv)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(tables_dir, self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(tables_dir, self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')
class NAB_dataset(TODS_dataset):
    """NAB realTweets (Twitter_volume_AMZN) dataset, hosted at hegsns.github.io."""

    resources = [
        # md5 checksums are not published yet, so the second tuple element
        # stays None and no integrity check is performed on download.
        ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.csv", None),
        ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.json", None),
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 2
    _repr_indent = 4

    def process(self) -> None:
        """Read the raw CSV, apply the transform, and write the processed files."""
        print('Processing...')

        tables_dir = os.path.join(self.processed_folder, 'tables')
        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(tables_dir, exist_ok=True)

        raw_csv = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv')
        self.training_set_dataframe = pd.read_csv(raw_csv)
        # The testing split is temporarily identical to the training split.
        self.testing_set_dataframe = pd.read_csv(raw_csv)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(tables_dir, self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(tables_dir, self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')
| # kpi_dataset(root='./datasets', train=True, transform='binarize') | |||||
| # yahoo_dataset(root='./datasets', train=True, transform='binarize') | |||||
| # NAB_dataset(root='./datasets', train=True, transform='binarize') | |||||
| @@ -1 +0,0 @@ | |||||
| Subproject commit af54e6970476a081bf0cd65990c9f56a1200d8a2 | |||||
| @@ -1 +0,0 @@ | |||||
| Subproject commit 046b20d2f6d4543dcbe18f0a1d4bcbb1f61cf518 | |||||
| @@ -1 +0,0 @@ | |||||
| Subproject commit 70aeefed6b7307941581357c4b7858bb3f88e1da | |||||
| @@ -0,0 +1,116 @@ | |||||
| import os | |||||
| import typing | |||||
| import numpy | |||||
| import pandas | |||||
| from d3m import container, exceptions, utils as d3m_utils | |||||
| from d3m.metadata import base as metadata_base, hyperparams | |||||
| from d3m.base import primitives | |||||
| __all__ = ('FixedSplitDatasetSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters for ``FixedSplitDatasetSplitPrimitive``.

    At most one of ``primary_index_values`` and ``row_indices`` may be
    non-empty; the primitive validates this at split time.
    """

    # Primary-key values (e.g. d3mIndex) selecting the test (score) rows.
    primary_index_values = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description='A set of primary index values of the main resource belonging to the test (score) split. Cannot be set together with "row_indices".',
    )
    # Positional row indices selecting the test (score) rows.
    row_indices = hyperparams.Set(
        elements=hyperparams.Hyperparameter[int](-1),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description='A set of row indices of the main resource belonging to the test (score) split. Cannot be set together with "primary_index_values".',
    )
    # Whether to prune rows of other resources no longer referenced by the split.
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )
class FixedSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset in a way that uses for the test
    (score) split a fixed list of primary index values or row indices of the main
    resource to be used. All other rows are used for the train split.
    """

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '1654f000-2178-4520-be4c-a95bc26b8d3a',
            'version': '0.1.0',
            'name': "Fixed split tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.fixed_split_dataset_split',
            'source': {
                'name': "DATALab@TexasA&M University",
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/fixed_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return a single ``(train_indices, score_indices)`` pair.

        Score rows are selected either by primary index values or by raw row
        indices (mutually exclusive hyper-parameters); every remaining row of
        ``attributes`` becomes training data.

        Raises:
            exceptions.InvalidArgumentValueError: when hyper-parameters have
                duplicates, are set together, or reference missing rows.
        """
        # Duplicates should be prevented by the "Set" hyper-parameter, but we
        # check here again just to be sure.
        if d3m_utils.has_duplicates(self.hyperparams['primary_index_values']):
            raise exceptions.InvalidArgumentValueError("\"primary_index_values\" hyper-parameter has duplicate values.")
        if d3m_utils.has_duplicates(self.hyperparams['row_indices']):
            raise exceptions.InvalidArgumentValueError("\"row_indices\" hyper-parameter has duplicate values.")

        if self.hyperparams['primary_index_values'] and self.hyperparams['row_indices']:
            raise exceptions.InvalidArgumentValueError("Both \"primary_index_values\" and \"row_indices\" cannot be provided.")

        if self.hyperparams['primary_index_values']:
            primary_index_values = numpy.array(self.hyperparams['primary_index_values'])

            index_columns = dataset.metadata.get_index_columns(at=(main_resource_id,))
            if not index_columns:
                raise exceptions.InvalidArgumentValueError("Cannot find index columns in the main resource of the dataset, but \"primary_index_values\" is provided.")

            main_resource = dataset[main_resource_id]
            # Reset the index so that the pandas index equals the row position.
            main_resource = main_resource.reset_index(drop=True)

            # We use just the "d3mIndex" column and ignore multi-key indices.
            # This works for now because it seems that every current multi-key
            # dataset in fact has an unique value in "d3mIndex" alone.
            # See: https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/issues/117
            index_column = index_columns[0]

            score_data = numpy.array(main_resource.loc[main_resource.iloc[:, index_column].isin(primary_index_values)].index)
            score_data_set = set(score_data)

            assert len(score_data) == len(score_data_set), (len(score_data), len(score_data_set))

            if len(score_data) != len(primary_index_values):
                raise exceptions.InvalidArgumentValueError("\"primary_index_values\" contains values which do not exist.")

        else:
            score_data = numpy.array(self.hyperparams['row_indices'])
            score_data_set = set(score_data)

            all_data_set = set(numpy.arange(len(attributes)))

            if not score_data_set <= all_data_set:
                raise exceptions.InvalidArgumentValueError("\"row_indices\" contains indices which do not exist, e.g., {indices}.".format(
                    indices=sorted(score_data_set - all_data_set)[:5],
                ))

        # Everything not selected for scoring becomes training data.
        train_data = [i for i in numpy.arange(len(attributes)) if i not in score_data_set]

        assert len(train_data) + len(score_data) == len(attributes), (len(train_data), len(score_data), len(attributes))

        return [(numpy.array(train_data), score_data)]
| @@ -0,0 +1,87 @@ | |||||
| import os | |||||
| import typing | |||||
| import numpy | |||||
| import pandas | |||||
| from sklearn import model_selection | |||||
| from d3m import container, exceptions, utils as d3m_utils | |||||
| from d3m.metadata import base as metadata_base, hyperparams | |||||
| from d3m.base import primitives | |||||
| __all__ = ('KFoldDatasetSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters for ``KFoldDatasetSplitPrimitive``."""

    # k: how many folds the dataset is split into (must be at least 2).
    number_of_folds = hyperparams.Bounded[int](
        lower=2,
        upper=None,
        default=5,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Number of folds for k-folds cross-validation.",
    )
    # When True, folds preserve per-class sample proportions (requires targets).
    stratified = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
    )
    # When True, rows are shuffled before being assigned to folds.
    shuffle = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Whether to shuffle the data before splitting into batches.",
    )
    # Whether to prune rows of other resources no longer referenced by the split.
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )
class KFoldDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset for k-fold cross-validation.
    """

    __author__ = 'Mingjie Sun <sunmj15@gmail.com>'

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': 'bfedaf3a-6dd0-4a83-ad83-3a50fe882bf8',
            'version': '0.1.0',
            'name': "K-fold cross-validation tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.kfold_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:sunmj15@gmail.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.K_FOLD,
                metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return the k ``(train_indices, test_indices)`` pairs from sklearn's
        (Stratified)KFold.

        Raises:
            exceptions.InvalidArgumentValueError: when a stratified split is
                requested but ``targets`` has no columns.
        """
        shuffle = self.hyperparams['shuffle']
        # scikit-learn (>= 0.24) raises a ValueError when "random_state" is set
        # while "shuffle" is False, so only forward the seed when shuffling.
        random_state = self._random_state if shuffle else None

        if self.hyperparams['stratified']:
            if not len(targets.columns):
                raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")

            k_fold = model_selection.StratifiedKFold(
                n_splits=self.hyperparams['number_of_folds'],
                shuffle=shuffle,
                random_state=random_state,
            )
        else:
            k_fold = model_selection.KFold(
                n_splits=self.hyperparams['number_of_folds'],
                shuffle=shuffle,
                random_state=random_state,
            )

        return list(k_fold.split(attributes, targets))
| @@ -0,0 +1,187 @@ | |||||
| import os | |||||
| import uuid | |||||
| import typing | |||||
| from collections import OrderedDict | |||||
| import numpy | |||||
| import pandas | |||||
| from sklearn import model_selection | |||||
| from d3m import container, exceptions, utils as d3m_utils | |||||
| from d3m.metadata import base as metadata_base, hyperparams | |||||
| from d3m.base import primitives | |||||
| import utils | |||||
| __all__ = ('KFoldTimeSeriesSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters for ``KFoldTimeSeriesSplitPrimitive``."""

    # k: how many folds the time series is split into (must be at least 2).
    number_of_folds = hyperparams.Bounded[int](
        lower=2,
        upper=None,
        default=5,
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Number of folds for k-folds cross-validation.",
    )
    # Size of the train window in folds; None means "use all preceding folds".
    number_of_window_folds = hyperparams.Union[typing.Union[int, None]](
        configuration=OrderedDict(
            fixed=hyperparams.Bounded[int](
                lower=1,
                upper=None,
                default=1,
                description="Number of folds in train set (window). These folds come directly "
                            "before test set (streaming window).",
            ),
            all_records=hyperparams.Constant(
                default=None,
                description="Number of folds in train set (window) = maximum number possible.",
            ),
        ),
        default='all_records',
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Maximum size for a single training set.",
    )
    # Which column carries the time index; None auto-detects via semantic types.
    time_column_index = hyperparams.Union[typing.Union[int, None]](
        configuration=OrderedDict(
            fixed=hyperparams.Bounded[int](
                lower=1,
                upper=None,
                default=1,
                description="Specific column that contains the time index",
            ),
            one_column=hyperparams.Constant(
                default=None,
                description="Only one column contains a time index. "
                            "It is detected automatically using semantic types.",
            ),
        ),
        default='one_column',
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Column index to use as datetime index. "
                    "If None, it is required that only one column with time column role semantic type is "
                    "present and otherwise an exception is raised. "
                    "If column index specified is not a datetime column an exception is"
                    "also raised.",
    )
    # Whether the datetime parser tolerates loosely formatted timestamps.
    fuzzy_time_parsing = hyperparams.UniformBool(
        default=True,
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Use fuzzy time parsing.",
    )
| class KFoldTimeSeriesSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]): | |||||
| """ | |||||
| A primitive which splits a tabular time-series Dataset for k-fold cross-validation. | |||||
| Primitive sorts the time column so care should be taken to assure sorting of a | |||||
| column is reasonable. E.g., if column is not numeric but of string structural type, | |||||
| strings should be formatted so that sorting by them also sorts by time. | |||||
| """ | |||||
| __author__ = 'Distil' | |||||
| __version__ = '0.3.0' | |||||
| __contact__ = 'mailto:jeffrey.gleason@yonder.co' | |||||
| metadata = metadata_base.PrimitiveMetadata( | |||||
| { | |||||
| 'id': '002f9ad1-46e3-40f4-89ed-eeffbb3a102b', | |||||
| 'version': __version__, | |||||
| 'name': "K-fold cross-validation timeseries dataset splits", | |||||
| 'python_path': 'd3m.primitives.tods.evaluation.kfold_time_series_split', | |||||
| 'source': { | |||||
| 'name': 'DATALab@Texas A&M University', | |||||
| 'contact': __contact__, | |||||
| 'uris': [ | |||||
| 'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split_timeseries.py', | |||||
| 'https://gitlab.com/datadrivendiscovery/common-primitives.git', | |||||
| ], | |||||
| }, | |||||
| 'algorithm_types': [ | |||||
| metadata_base.PrimitiveAlgorithmType.K_FOLD, | |||||
| metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION, | |||||
| metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING, | |||||
| ], | |||||
| 'primitive_family': metadata_base.PrimitiveFamily.EVALUATION, | |||||
| }, | |||||
| ) | |||||
    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """
        Produce (train, test) pairs of row indices into ``attributes`` using
        scikit-learn's ``TimeSeriesSplit``, after sorting rows by the time column.

        :param attributes: Attribute columns of the main resource, one row per sample.
        :param targets: Target columns (not used by this split).
        :param dataset: Full dataset; its metadata is used to locate the time column.
        :param main_resource_id: Resource ID of the dataset entry point.
        :return: List of (train indices, test indices) pairs expressed in the
            original (pre-sort) row order of ``attributes``.
        :raises exceptions.InvalidArgumentValueError: If no unambiguous time column
            can be determined from the hyper-parameters and metadata.
        """
        # Columns marked with the "Time" semantic type in the main resource.
        time_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Time'], at=(main_resource_id,))
        attribute_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Attribute'], at=(main_resource_id,))

        # We want only time columns which are also attributes.
        time_column_indices = [time_column_index for time_column_index in time_column_indices if time_column_index in attribute_column_indices]

        if self.hyperparams['time_column_index'] is None:
            if len(time_column_indices) != 1:
                raise exceptions.InvalidArgumentValueError(
                    "If \"time_column_index\" hyper-parameter is \"None\", it is required that exactly one column with time column role semantic type is present.",
                )
            else:
                # We know it exists because "time_column_indices" is a subset of "attribute_column_indices".
                # Note: ".index" converts a dataset-level column index into a
                # position within "attributes", which holds only attribute columns.
                time_column_index = attribute_column_indices.index(
                    time_column_indices[0],
                )
        else:
            if self.hyperparams['time_column_index'] not in time_column_indices:
                raise exceptions.InvalidArgumentValueError(
                    "Time column index specified does not have a time column role semantic type.",
                )
            else:
                # Same dataset-index to attributes-position conversion as above.
                time_column_index = attribute_column_indices.index(
                    self.hyperparams['time_column_index'],
                )

        # We first reset index.
        attributes = attributes.reset_index(drop=True)

        # Then convert datetime column to consistent datetime representation
        attributes.insert(
            loc=0,
            column=uuid.uuid4(),  # use uuid to ensure we are inserting a new column name
            value=self._parse_time_data(
                attributes, time_column_index, self.hyperparams['fuzzy_time_parsing'],
            ),
        )

        # Then sort dataframe by new datetime column. Index contains original row order.
        attributes = attributes.sort_values(by=attributes.columns[0])

        # Remove datetime representation used for sorting (primitives might choose to parse this str col differently).
        attributes = attributes.drop(attributes.columns[0], axis=1)

        max_train_size: typing.Optional[int] = None
        if self.hyperparams['number_of_window_folds'] is not None:
            # Cap each training window at a fixed fraction of the rows,
            # proportional to "number_of_window_folds" / "number_of_folds".
            max_train_size = int(attributes.shape[0] * self.hyperparams['number_of_window_folds'] / self.hyperparams['number_of_folds'])

        k_fold = model_selection.TimeSeriesSplit(
            n_splits=self.hyperparams['number_of_folds'],
            max_train_size=max_train_size
        )

        # We sorted "attributes" so we have to map indices on sorted "attributes" back to original
        # indices. We do that by using DataFrame's index which contains original row order.
        return [
            (
                numpy.array([attributes.index[val] for val in train]),
                numpy.array([attributes.index[val] for val in test]),
            )
            for train, test in k_fold.split(attributes)
        ]
| @classmethod | |||||
| def _parse_time_data(cls, inputs: container.DataFrame, column_index: metadata_base.SimpleSelectorSegment, fuzzy: bool) -> typing.List[float]: | |||||
| return [ | |||||
| utils.parse_datetime_to_float(value, fuzzy=fuzzy) | |||||
| for value in inputs.iloc[:, column_index] | |||||
| ] | |||||
| @@ -0,0 +1,52 @@ | |||||
| import os | |||||
| import typing | |||||
| import numpy | |||||
| import pandas | |||||
| from d3m import container, utils as d3m_utils | |||||
| from d3m.metadata import base as metadata_base, hyperparams | |||||
| from d3m.base import primitives | |||||
| __all__ = ('NoSplitDatasetSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams):
    # This primitive takes no hyper-parameters; an (empty) Hyperparams class is
    # still required by the d3m primitive interface.
    pass
class NoSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset in a way that for all splits it
    produces the same (full) Dataset. Useful for unsupervised learning tasks.
    """

    # D3M primitive metadata: unique ID, version, registration path, provenance
    # URIs, and taxonomy (an identity-function data split used during evaluation).
    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '48c683ad-da9e-48cf-b3a0-7394dba5e5d2',
            'version': '0.1.0',
            'name': "No-split tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.no_split_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/no_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.IDENTITY_FUNCTION,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """
        Return a single split in which both the train and the test side contain
        every row index of ``attributes``.
        """
        # We still go through the whole splitting process to assure full compatibility
        # (and error conditions) of a regular split, but we use all data for both splits.
        all_data = numpy.arange(len(attributes))

        return [(all_data, all_data)]
| @@ -0,0 +1,160 @@ | |||||
| import copy | |||||
| import os | |||||
| import typing | |||||
| from d3m import container, exceptions, utils as d3m_utils | |||||
| from d3m.metadata import base as metadata_base, hyperparams | |||||
| from d3m.primitive_interfaces import base, transformer | |||||
# Both the input and the output of this primitive are a list of Dataset objects
# (see "RedactColumnsPrimitive.produce" below, which iterates over datasets).
Inputs = container.List
Outputs = container.List
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters selecting which columns get redacted and what semantic types are added to them."""

    match_logic = hyperparams.Enumeration(
        values=['all', 'any'],
        default='any',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Should a column have all of semantic types in \"semantic_types\" to be redacted, or any of them?",
    )
    semantic_types = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Redact columns with these semantic types. Only columns having semantic types listed here will be operated on, based on \"match_logic\".",
    )
    add_semantic_types = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Semantic types to add to redacted columns. All listed semantic types will be added to all columns which were redacted.",
    )
# TODO: Make clear the assumption that both the container type (List) and the Datasets should have metadata.
#       The primitive is modifying metadata of Datasets, while there is officially no reason for them
#       to really have metadata: metadata is stored on the input container type, not
#       on the values inside it.
class RedactColumnsPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
    """
    A primitive which takes as an input a list of ``Dataset`` objects and redacts values of all columns matching
    a given semantic type or types.

    Redaction is done by setting all values in a redacted column to an empty string.

    It operates only on DataFrame resources inside datasets.
    """

    # D3M primitive metadata: unique ID, version, registration path, provenance
    # URIs, pip installation source, and taxonomy.
    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '744c4090-e2f6-489e-8efc-8b1e051bfad6',
            'version': '0.2.0',
            'name': "Redact columns for evaluation",
            'python_path': 'd3m.primitives.tods.evaluation.redact_columns',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:sunmj15@gmail.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/redact_columns.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'installation': [{
                'type': metadata_base.PrimitiveInstallationType.PIP,
                'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format(
                    git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)),
                ),
            }],
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.DATA_CONVERSION,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
        """
        Redact matching columns in every dataset of ``inputs``.

        :param inputs: A list of ``Dataset`` objects.
        :return: A new list of datasets in which every matched string column has
            all of its values replaced with the empty string.
        :raises TypeError: If a column selected for redaction does not have
            structural type ``str``.
        """
        output_datasets = container.List(generate_metadata=True)

        for dataset in inputs:
            resources = {}
            metadata = dataset.metadata

            for resource_id, resource in dataset.items():
                # Non-DataFrame resources are passed through untouched.
                if not isinstance(resource, container.DataFrame):
                    resources[resource_id] = resource
                    continue

                columns_to_redact = self._get_columns_to_redact(metadata, (resource_id,))

                # Nothing matched: keep the resource as-is.
                if not columns_to_redact:
                    resources[resource_id] = resource
                    continue

                # NOTE(review): "copy.copy" is a shallow copy; the "iloc"
                # assignment below may still write into data shared with the
                # input resource — confirm this is intended.
                resource = copy.copy(resource)

                for column_index in columns_to_redact:
                    column_metadata = dataset.metadata.query((resource_id, metadata_base.ALL_ELEMENTS, column_index))
                    # Only string columns can be redacted, since redaction writes empty strings.
                    if 'structural_type' in column_metadata and issubclass(column_metadata['structural_type'], str):
                        resource.iloc[:, column_index] = ''
                    else:
                        raise TypeError("Primitive can operate only on columns with structural type \"str\", not \"{type}\".".format(
                            type=column_metadata.get('structural_type', None),
                        ))

                    # Record the configured extra semantic types on the redacted column.
                    metadata = self._update_metadata(metadata, resource_id, column_index, ())

                resources[resource_id] = resource

            dataset = container.Dataset(resources, metadata)

            output_datasets.append(dataset)

        # Build top-level metadata for the output list container.
        output_datasets.metadata = metadata_base.DataMetadata({
            'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
            'structural_type': container.List,
            'dimension': {
                'length': len(output_datasets),
            },
        })

        # We update metadata based on metadata of each dataset.
        # TODO: In the future this might be done automatically by generate_metadata.
        #       See: https://gitlab.com/datadrivendiscovery/d3m/issues/119
        for index, dataset in enumerate(output_datasets):
            output_datasets.metadata = dataset.metadata.copy_to(output_datasets.metadata, (), (index,))

        return base.CallResult(output_datasets)

    def _get_columns_to_redact(self, inputs_metadata: metadata_base.DataMetadata, at: metadata_base.Selector) -> typing.Sequence[int]:
        """
        Return indices of columns (under selector ``at``) whose semantic types
        match the "semantic_types" hyper-parameter per the "match_logic" rule.
        """
        columns = []

        for element in inputs_metadata.get_elements(list(at) + [metadata_base.ALL_ELEMENTS]):
            semantic_types = inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS, element]).get('semantic_types', ())

            # TODO: Should we handle inheritance between semantic types here?
            if self.hyperparams['match_logic'] == 'all':
                matched = all(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types'])
            elif self.hyperparams['match_logic'] == 'any':
                matched = any(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types'])
            else:
                raise exceptions.UnexpectedValueError("Unknown value of hyper-parameter \"match_logic\": {value}".format(value=self.hyperparams['match_logic']))

            if matched:
                if element is metadata_base.ALL_ELEMENTS:
                    # Match applies to all columns at once: redact the whole table width.
                    return list(range(inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS]).get('dimension', {}).get('length', 0)))
                else:
                    columns.append(typing.cast(int, element))

        return columns

    def _update_metadata(
        self, inputs_metadata: metadata_base.DataMetadata, resource_id: metadata_base.SelectorSegment,
        column_index: int, at: metadata_base.Selector,
    ) -> metadata_base.DataMetadata:
        """
        Return ``inputs_metadata`` with every semantic type from the
        "add_semantic_types" hyper-parameter added to the given column.
        """
        outputs_metadata = inputs_metadata

        for semantic_type in self.hyperparams['add_semantic_types']:
            outputs_metadata = outputs_metadata.add_semantic_type(tuple(at) + (resource_id, metadata_base.ALL_ELEMENTS, column_index), semantic_type)

        return outputs_metadata
| @@ -0,0 +1,88 @@ | |||||
| import os | |||||
| import typing | |||||
| import numpy | |||||
| import pandas | |||||
| from sklearn import model_selection | |||||
| from d3m import container, exceptions, utils as d3m_utils | |||||
| from d3m.metadata import base as metadata_base, hyperparams | |||||
| from d3m.base import primitives | |||||
| __all__ = ('TrainScoreDatasetSplitPrimitive',) | |||||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters controlling the train/score split ratio, stratification, and shuffling."""

    train_score_ratio = hyperparams.Uniform(
        lower=0,
        upper=1,
        default=0.75,
        upper_inclusive=True,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="The ratio between the train and score data and represents the proportion of the Dataset to include in the train split. The rest is included in the score split.",
    )
    stratified = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
    )
    shuffle = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Whether to shuffle the data before splitting into batches.",
    )
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )
class TrainScoreDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset into random train and score subsets.
    """

    # D3M primitive metadata: unique ID, version, registration path, provenance
    # URIs, pip installation source, and taxonomy (holdout-style data split).
    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '3fcc6dc4-6681-4c86-948e-066d14e7d803',
            'version': '0.1.0',
            'name': "Train-score tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.train_score_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/train_score_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'installation': [{
                'type': metadata_base.PrimitiveInstallationType.PIP,
                'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format(
                    git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)),
                ),
            }],
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.HOLDOUT,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """
        Produce a single (train, score) pair of row-index arrays over ``attributes``.

        :raises exceptions.InvalidArgumentValueError: If "stratified" is requested
            but ``targets`` has no columns to stratify on.
        """
        if self.hyperparams['stratified'] and not len(targets.columns):
            raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")

        # We split row positions (0 .. n-1), not the rows themselves.
        # NOTE(review): scikit-learn rejects stratify combined with shuffle=False
        # — confirm the hyper-parameter combination is validated upstream.
        train_data, score_data = model_selection.train_test_split(
            numpy.arange(len(attributes)),
            test_size=None,
            train_size=self.hyperparams['train_score_ratio'],
            random_state=self._random_state,
            shuffle=self.hyperparams['shuffle'],
            stratify=targets if self.hyperparams['stratified'] else None,
        )

        return [(train_data, score_data)]
| @@ -0,0 +1,192 @@ | |||||
| import datetime | |||||
| import logging | |||||
| import typing | |||||
| import dateutil.parser | |||||
| import numpy | |||||
| from d3m import container, deprecate | |||||
| from d3m.base import utils as base_utils | |||||
| from d3m.metadata import base as metadata_base | |||||
| logger = logging.getLogger(__name__) | |||||
| DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) | |||||
@deprecate.function(message="it should not be used anymore")
def copy_elements_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector,
                           to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Delegates to the private ``Metadata._copy_elements_metadata``; ``check`` and ``source`` are accepted but ignored."""
    return source_metadata._copy_elements_metadata(target_metadata, list(from_selector), list(to_selector), [], ignore_all_elements)


@deprecate.function(message="use Metadata.copy_to method instead")
def copy_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector,
                  to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``Metadata.copy_to`` instead; ``check`` and ``source`` are accepted but ignored."""
    return source_metadata.copy_to(target_metadata, from_selector, to_selector, ignore_all_elements=ignore_all_elements)
@deprecate.function(message="use DataFrame.select_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def select_columns(inputs: container.DataFrame, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *,
                   source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``DataFrame.select_columns`` instead."""
    return inputs.select_columns(columns)


@deprecate.function(message="use DataMetadata.select_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def select_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *,
                            source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.select_columns`` instead."""
    return inputs_metadata.select_columns(columns)


@deprecate.function(message="use DataMetadata.list_columns_with_semantic_types method instead")
def list_columns_with_semantic_types(metadata: metadata_base.DataMetadata, semantic_types: typing.Sequence[str], *,
                                     at: metadata_base.Selector = ()) -> typing.Sequence[int]:
    """Deprecated. Use ``DataMetadata.list_columns_with_semantic_types`` instead."""
    return metadata.list_columns_with_semantic_types(semantic_types, at=at)


@deprecate.function(message="use DataMetadata.list_columns_with_structural_types method instead")
def list_columns_with_structural_types(metadata: metadata_base.DataMetadata, structural_types: typing.Union[typing.Callable, typing.Sequence[typing.Union[str, type]]], *,
                                       at: metadata_base.Selector = ()) -> typing.Sequence[int]:
    """Deprecated. Use ``DataMetadata.list_columns_with_structural_types`` instead."""
    return metadata.list_columns_with_structural_types(structural_types, at=at)
@deprecate.function(message="use DataFrame.remove_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def remove_columns(inputs: container.DataFrame, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``DataFrame.remove_columns`` instead."""
    return inputs.remove_columns(column_indices)


@deprecate.function(message="use DataMetadata.remove_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def remove_columns_metadata(inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.remove_columns`` instead."""
    return inputs_metadata.remove_columns(column_indices)


@deprecate.function(message="use DataFrame.append_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def append_columns(left: container.DataFrame, right: container.DataFrame, *, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``DataFrame.append_columns`` instead."""
    return left.append_columns(right, use_right_metadata=use_right_metadata)


@deprecate.function(message="use DataMetadata.append_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def append_columns_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.append_columns`` instead.

    NOTE(review): unlike its siblings, ``use_right_metadata`` here is positional
    (no ``*`` marker) — confirm whether this asymmetry is intentional.
    """
    return left_metadata.append_columns(right_metadata, use_right_metadata=use_right_metadata)
@deprecate.function(message="use DataFrame.insert_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def insert_columns(inputs: container.DataFrame, columns: container.DataFrame, at_column_index: int, *, source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``DataFrame.insert_columns`` instead."""
    return inputs.insert_columns(columns, at_column_index)


@deprecate.function(message="use DataMetadata.insert_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def insert_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, at_column_index: int, *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.insert_columns`` instead."""
    return inputs_metadata.insert_columns(columns_metadata, at_column_index)


@deprecate.function(message="use DataFrame.replace_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def replace_columns(inputs: container.DataFrame, columns: container.DataFrame, column_indices: typing.Sequence[int], *, copy: bool = True, source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``DataFrame.replace_columns`` instead."""
    return inputs.replace_columns(columns, column_indices, copy=copy)


@deprecate.function(message="use DataMetadata.replace_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def replace_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.replace_columns`` instead."""
    return inputs_metadata.replace_columns(columns_metadata, column_indices)
@deprecate.function(message="use DataMetadata.get_index_columns method instead")
def get_index_columns(metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = ()) -> typing.Sequence[int]:
    """Deprecated. Use ``DataMetadata.get_index_columns`` instead."""
    return metadata.get_index_columns(at=at)


@deprecate.function(message="use DataFrame.horizontal_concat method instead")
@deprecate.arguments('source', message="argument ignored")
def horizontal_concat(left: container.DataFrame, right: container.DataFrame, *, use_index: bool = True,
                      remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``DataFrame.horizontal_concat`` instead."""
    return left.horizontal_concat(right, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata)


@deprecate.function(message="use DataMetadata.horizontal_concat method instead")
@deprecate.arguments('source', message="argument ignored")
def horizontal_concat_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_index: bool = True,
                               remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.horizontal_concat`` instead."""
    return left_metadata.horizontal_concat(right_metadata, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata)
@deprecate.function(message="use d3m.base.utils.get_columns_to_use function instead")
def get_columns_to_use(metadata: metadata_base.DataMetadata, use_columns: typing.Sequence[int], exclude_columns: typing.Sequence[int],
                       can_use_column: typing.Callable) -> typing.Tuple[typing.List[int], typing.List[int]]:
    """Deprecated. Use ``d3m.base.utils.get_columns_to_use`` instead."""
    return base_utils.get_columns_to_use(metadata, use_columns, exclude_columns, can_use_column)


@deprecate.function(message="use d3m.base.utils.combine_columns function instead")
@deprecate.arguments('source', message="argument ignored")
def combine_columns(return_result: str, add_index_columns: bool, inputs: container.DataFrame, column_indices: typing.Sequence[int],
                    columns_list: typing.Sequence[container.DataFrame], *, source: typing.Any = None) -> container.DataFrame:
    """Deprecated. Use ``d3m.base.utils.combine_columns`` instead (note the different argument order of the replacement)."""
    return base_utils.combine_columns(inputs, column_indices, columns_list, return_result=return_result, add_index_columns=add_index_columns)


@deprecate.function(message="use d3m.base.utils.combine_columns_metadata function instead")
@deprecate.arguments('source', message="argument ignored")
def combine_columns_metadata(return_result: str, add_index_columns: bool, inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int],
                             columns_metadata_list: typing.Sequence[metadata_base.DataMetadata], *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``d3m.base.utils.combine_columns_metadata`` instead (note the different argument order of the replacement)."""
    return base_utils.combine_columns_metadata(inputs_metadata, column_indices, columns_metadata_list, return_result=return_result, add_index_columns=add_index_columns)
@deprecate.function(message="use DataMetadata.set_table_metadata method instead")
@deprecate.arguments('source', message="argument ignored")
def set_table_metadata(inputs_metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = (), source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated. Use ``DataMetadata.set_table_metadata`` instead."""
    return inputs_metadata.set_table_metadata(at=at)


@deprecate.function(message="use DataMetadata.get_column_index_from_column_name method instead")
def get_column_index_from_column_name(inputs_metadata: metadata_base.DataMetadata, column_name: str, *, at: metadata_base.Selector = ()) -> int:
    """Deprecated. Use ``DataMetadata.get_column_index_from_column_name`` instead."""
    return inputs_metadata.get_column_index_from_column_name(column_name, at=at)


@deprecate.function(message="use Dataset.get_relations_graph method instead")
def build_relation_graph(dataset: container.Dataset) -> typing.Dict[str, typing.List[typing.Tuple[str, bool, int, int, typing.Dict]]]:
    """Deprecated. Use ``Dataset.get_relations_graph`` instead."""
    return dataset.get_relations_graph()


@deprecate.function(message="use d3m.base.utils.get_tabular_resource function instead")
def get_tabular_resource(dataset: container.Dataset, resource_id: typing.Optional[str], *,
                         pick_entry_point: bool = True, pick_one: bool = True, has_hyperparameter: bool = True) -> typing.Tuple[str, container.DataFrame]:
    """Deprecated. Use ``d3m.base.utils.get_tabular_resource`` instead."""
    return base_utils.get_tabular_resource(dataset, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one, has_hyperparameter=has_hyperparameter)


@deprecate.function(message="use d3m.base.utils.get_tabular_resource_metadata function instead")
def get_tabular_resource_metadata(dataset_metadata: metadata_base.DataMetadata, resource_id: typing.Optional[metadata_base.SelectorSegment], *,
                                  pick_entry_point: bool = True, pick_one: bool = True) -> metadata_base.SelectorSegment:
    """Deprecated. Use ``d3m.base.utils.get_tabular_resource_metadata`` instead."""
    return base_utils.get_tabular_resource_metadata(dataset_metadata, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one)


@deprecate.function(message="use Dataset.select_rows method instead")
@deprecate.arguments('source', message="argument ignored")
def cut_dataset(dataset: container.Dataset, row_indices_to_keep: typing.Mapping[str, typing.Sequence[int]], *,
                source: typing.Any = None) -> container.Dataset:
    """Deprecated. Use ``Dataset.select_rows`` instead."""
    return dataset.select_rows(row_indices_to_keep)
def parse_datetime(value: str, *, fuzzy: bool = True) -> typing.Optional[datetime.datetime]:
    """
    Parse ``value`` into a ``datetime`` using dateutil, filling missing fields
    from ``DEFAULT_DATETIME``. Returns ``None`` when the value cannot be parsed.
    """
    result: typing.Optional[datetime.datetime]
    try:
        result = dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=fuzzy)
    except (ValueError, OverflowError, TypeError):
        # Unparseable, out-of-range, or non-string input.
        result = None
    return result
def parse_datetime_to_float(value: str, *, fuzzy: bool = True) -> float:
    """
    Parse ``value`` into a POSIX timestamp (seconds since epoch, as float).
    Returns ``numpy.nan`` when the value cannot be parsed or converted.
    """
    parsed = parse_datetime(value, fuzzy=fuzzy)
    if parsed is None:
        return numpy.nan

    # "parse_datetime" already handles parsing failures, so only the timestamp
    # conversion needs guarding: "datetime.timestamp" can raise OverflowError
    # (and OSError on some platforms) for dates outside the supported range.
    try:
        return parsed.timestamp()
    except (ValueError, OverflowError, OSError, TypeError):
        return numpy.nan