Former-commit-id:masterb0d35d9aea[formerly28b7582d25] [formerly632f8af13a[formerly9eb3583256]] [formerlyab6f5a8a9c[formerly2720e70c78] [formerlyc1f615a3c5[formerly21840b3787]]] [formerly2184b60ad6[formerly2dae6c4ed4] [formerly6a577865b8[formerlyee28aaed84]] [formerly72078d88c8[formerlyab1fb9d512] [formerly476600a47c[formerlyf5a6205bb5]]]] [formerlye5c2c3deef[formerly8585dc41a9] [formerly9cc7fe2088[formerly1d2104316f]] [formerlyb51e614a11[formerlyf9891a191d] [formerly048aa2f114[formerlyd4de64574b]]] [formerlyff66f55a51[formerly0b691b9a7f] [formerlye64e8dd253[formerlycc45939e01]] [formerlycdad10712a[formerly2789e20e79] [formerly25924f293c[formerly32997accab]]]]] [formerlyf43b431040[formerly95815d02ca] [formerlyfe9bd45d44[formerly6daf0aa73e]] [formerly61ab30c9a3[formerlya13c6e23b4] [formerly86fa5919ee[formerly1e49e1a303]]] [formerly69c5bc967a[formerlyab82915cbd] [formerlyf8057c3b14[formerly5232f34578]] [formerly671c54e952[formerly6454a28f26] [formerly3db6ff66b9[formerlyaa8c7fe127]]]] [formerly86b2ec6b84[formerlyf35c344efe] [formerlyd5616f66cd[formerly98c9dca7da]] [formerlya7dcc62bc5[formerly4ef4fa0c98] [formerly55f670b9ae[formerly1cd4421e2e]]] [formerlyd7a5bab832[formerlyc77c5b48df] [formerly01ffd33e2f[formerlyaea728ceb6]] [formerly16afb18e35[formerly4768b156f3] [formerly3c1298c626[formerly1e61cf0974]]]]]] Former-commit-id:28b09a56cb[formerly5241dbb36c] [formerlyd43909d979[formerly43b0cca7f5]] [formerly8d8f384c8e[formerlybcf58203c6] [formerlyca56bff2d0[formerly7a8750ffc9]]] [formerly2ce9fa87ae[formerlye7b1b542f5] [formerly62a7edf94a[formerly26ca5f220d]] [formerly73d10253b9[formerlyaf463cecb0] [formerly961314f474[formerly6e7141e5e2]]]] [formerlyc3f8938ba5[formerly7f66748292] [formerly3e5ef2c136[formerly266ddb4ccc]] [formerly4a8a5437b3[formerlycbcb0f8777] [formerlyc212b8217f[formerly5f3d3d01c8]]] [formerly09764ba6cd[formerly4db991be5f] [formerlyf79c6ec15d[formerly7f8eb54d47]] [formerly771b00b188[formerlyf1dcba565f] [formerly83c42510ad[formerly3c1298c626]]]]] 
Former-commit-id:f5cdcca4f3[formerlydc57f947a2] [formerly55cd6eb9a3[formerlye92d6c0923]] [formerlyba80ed43d2[formerly0a2f65401a] [formerly73c0c2ebb3[formerlybf80cf285e]]] [formerlya34f21d933[formerlyae311cb3e7] [formerlyae0e3ed079[formerlyc9030d0303]] [formerlyf8f4f6a8ec[formerly343dd65df6] [formerly1c97e6a7ba[formerlyc69d281aca]]]] Former-commit-id:b93de57240[formerly6fd32d0759] [formerlyff29e71cb0[formerlyee483b56f9]] [formerly2f944ec28b[formerly15c46e806c] [formerlydca0c18b8b[formerlyf185a8658c]]] Former-commit-id:76ed7f184f[formerlyc8adbe1dea] [formerly392fc3e54b[formerly5bd396738e]] Former-commit-id:11579edbe1[formerly6970795314] Former-commit-id:8fca90698b
| @@ -0,0 +1,297 @@ | |||
| import os | |||
| import os.path | |||
| import hashlib | |||
| import gzip | |||
| import errno | |||
| import tarfile | |||
| from typing import Any, Callable, List, Iterable, Optional, TypeVar | |||
| import zipfile | |||
| # import torch | |||
| # from torch.utils.model_zoo import tqdm | |||
| from hub import tqdm | |||
def gen_bar_updater() -> Callable[[int, int, int], None]:
    """Build a ``reporthook`` callback for ``urlretrieve`` that drives a tqdm bar."""
    progress = tqdm(total=None)

    def _update(blocks: int, block_size: int, total_size: int) -> None:
        # The total is only known once the server reports a Content-Length,
        # so fill it in lazily on the first call that carries one.
        if total_size and progress.total is None:
            progress.total = total_size
        downloaded = blocks * block_size
        progress.update(downloaded - progress.n)

    return _update
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the MD5 hex digest of the file at *fpath*, read in chunks."""
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    """Return True when the MD5 digest of *fpath* equals *md5*.

    Extra keyword arguments are forwarded to :func:`calculate_md5`.
    """
    return calculate_md5(fpath, **kwargs) == md5
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    """Return True if *fpath* is an existing file and, when *md5* is given,
    its checksum matches."""
    if not os.path.isfile(fpath):
        return False
    # No checksum supplied: mere existence is enough.
    return md5 is None or check_md5(fpath, md5)
def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: if the downloaded file is missing or fails the MD5 check.
    """
    # BUG FIX: ``import urllib`` alone does not bind the ``urllib.request`` or
    # ``urllib.error`` submodules, so the calls below would raise
    # AttributeError unless some other module had already imported them.
    import urllib.request
    import urllib.error

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:   # download the file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == 'https':
                # Some servers refuse https; retry the same path over http.
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater()
                )
            else:
                # BUG FIX: bare ``raise`` preserves the original traceback
                # (``raise e`` re-raised from this frame instead).
                raise
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    found = []
    for entry in os.listdir(root):
        if not os.path.isdir(os.path.join(root, entry)):
            continue
        found.append(os.path.join(root, entry) if prefix is True else entry)
    return found
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    matches = []
    for entry in os.listdir(root):
        full_path = os.path.join(root, entry)
        if os.path.isfile(full_path) and entry.endswith(suffix):
            matches.append(full_path if prefix is True else entry)
    return matches
| def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined] | |||
| return "Google Drive - Quota exceeded" in response.text | |||
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    filename = filename or file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # Reuse a previously verified download when possible.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    session = requests.Session()
    response = session.get(url, params={'id': file_id}, stream=True)

    # Large files trigger a "can't scan for viruses" interstitial; the
    # confirmation token from its cookie lets us bypass it.
    token = _get_confirm_token(response)
    if token:
        response = session.get(url, params={'id': file_id, 'confirm': token}, stream=True)

    if _quota_exceeded(response):
        raise RuntimeError(
            "The daily quota of the file (unknown) is exceeded and it "
            "can't be downloaded. This is a limitation of Google Drive "
            "and can only be overcome by trying again later."
        )

    _save_response_content(response, fpath)
| def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] | |||
| for key, value in response.cookies.items(): | |||
| if key.startswith('download_warning'): | |||
| return value | |||
| return None | |||
def _save_response_content(
    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
) -> None:
    """Stream the HTTP response body to *destination*, updating a progress bar."""
    pbar = tqdm(total=None)
    written = 0
    with open(destination, "wb") as out:
        for chunk in response.iter_content(chunk_size):
            # Keep-alive heartbeats arrive as empty chunks; skip them.
            if not chunk:
                continue
            out.write(chunk)
            written += len(chunk)
            pbar.update(written - pbar.n)
    pbar.close()
| def _is_tarxz(filename: str) -> bool: | |||
| return filename.endswith(".tar.xz") | |||
| def _is_tar(filename: str) -> bool: | |||
| return filename.endswith(".tar") | |||
| def _is_targz(filename: str) -> bool: | |||
| return filename.endswith(".tar.gz") | |||
| def _is_tgz(filename: str) -> bool: | |||
| return filename.endswith(".tgz") | |||
| def _is_gzip(filename: str) -> bool: | |||
| return filename.endswith(".gz") and not filename.endswith(".tar.gz") | |||
| def _is_zip(filename: str) -> bool: | |||
| return filename.endswith(".zip") | |||
| def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None: | |||
| if to_path is None: | |||
| to_path = os.path.dirname(from_path) | |||
| if _is_tar(from_path): | |||
| with tarfile.open(from_path, 'r') as tar: | |||
| tar.extractall(path=to_path) | |||
| elif _is_targz(from_path) or _is_tgz(from_path): | |||
| with tarfile.open(from_path, 'r:gz') as tar: | |||
| tar.extractall(path=to_path) | |||
| elif _is_tarxz(from_path): | |||
| with tarfile.open(from_path, 'r:xz') as tar: | |||
| tar.extractall(path=to_path) | |||
| elif _is_gzip(from_path): | |||
| to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) | |||
| with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: | |||
| out_f.write(zip_f.read()) | |||
| elif _is_zip(from_path): | |||
| with zipfile.ZipFile(from_path, 'r') as z: | |||
| z.extractall(to_path) | |||
| else: | |||
| raise ValueError("Extraction of {} not supported".format(from_path)) | |||
| if remove_finished: | |||
| os.remove(from_path) | |||
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive from *url* into *download_root* and report the
    extraction target.

    NOTE(review): the actual ``extract_archive`` call is commented out below,
    so this currently only downloads and prints — confirm whether that is
    intentional before re-enabling it.

    Args:
        url: URL to download the archive from.
        download_root: directory the archive is saved into.
        extract_root: directory to extract into; defaults to *download_root*.
        filename: name to save the archive under; defaults to the URL basename.
        md5: optional MD5 checksum to verify the download against.
        remove_finished: forwarded to the (disabled) extraction step.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    # print(archive)
    # print(extract_root)
    # extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable: Iterable) -> str:
    """Render an iterable as comma-separated single-quoted items.

    An empty iterable yields ``"''"`` (two quote characters), matching the
    original concatenation-based implementation.
    """
    items = [str(item) for item in iterable]
    return "'{}'".format("', '".join(items))
T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None, custom_msg: Optional[str] = None,
) -> T:
    """Validate that *value* is a string and, optionally, one of *valid_values*.

    Args:
        value: the value to check.
        arg: name of the argument, used in error messages.
        valid_values: allowed values; when None, any string passes.
        custom_msg: verbatim message to raise instead of the default one.

    Returns:
        *value* unchanged, when it passes validation.

    Raises:
        ValueError: if *value* is not a str or not among *valid_values*.
    """
    # BUG FIX: the original tested ``torch._six.string_classes``, but torch is
    # commented out of this module's imports (guaranteed NameError), and
    # ``torch._six`` has been removed upstream. On Python 3 a plain ``str``
    # check is the equivalent.
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
| @@ -0,0 +1,559 @@ | |||
| import errno | |||
| import hashlib | |||
| import os | |||
| import re | |||
| import shutil | |||
| import sys | |||
| import tempfile | |||
| # import torch | |||
| import warnings | |||
| import zipfile | |||
| from urllib.request import urlopen, Request | |||
| from urllib.parse import urlparse # noqa: F401 | |||
# Progress-bar import with graceful degradation: prefer tqdm.auto (which picks
# the right frontend for notebooks vs. terminals), fall back to plain tqdm,
# and finally to a minimal local stand-in when tqdm is not installed at all.
try:
    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
except ImportError:
    try:
        from tqdm import tqdm
    except ImportError:
        # fake tqdm if it's not installed
        class tqdm(object):  # type: ignore

            def __init__(self, total=None, disable=False,
                         unit=None, unit_scale=None, unit_divisor=None):
                # total: expected number of units, or None when unknown.
                # disable: when True, update()/__exit__ become no-ops.
                self.total = total
                self.disable = disable
                self.n = 0
                # ignore unit, unit_scale, unit_divisor; they're just for real tqdm

            def update(self, n):
                # Advance the counter by n and rewrite the current stderr line
                # (raw byte count when total is unknown, percentage otherwise).
                if self.disable:
                    return
                self.n += n
                if self.total is None:
                    sys.stderr.write("\r{0:.1f} bytes".format(self.n))
                else:
                    sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
                sys.stderr.flush()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Terminate the in-place progress line so later output starts fresh.
                if self.disable:
                    return
                sys.stderr.write('\n')
| # # matches bfd8deac from resnet18-bfd8deac.pth | |||
| # HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||
| # | |||
| # MASTER_BRANCH = 'master' | |||
| # ENV_TORCH_HOME = 'TORCH_HOME' | |||
| # ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | |||
| # DEFAULT_CACHE_DIR = '~/.cache' | |||
| # VAR_DEPENDENCY = 'dependencies' | |||
| # MODULE_HUBCONF = 'hubconf.py' | |||
| # READ_DATA_CHUNK = 8192 | |||
| # _hub_dir = None | |||
| # | |||
| # | |||
| # # Copied from tools/shared/module_loader to be included in torch package | |||
| # def import_module(name, path): | |||
| # import importlib.util | |||
| # from importlib.abc import Loader | |||
| # spec = importlib.util.spec_from_file_location(name, path) | |||
| # module = importlib.util.module_from_spec(spec) | |||
| # assert isinstance(spec.loader, Loader) | |||
| # spec.loader.exec_module(module) | |||
| # return module | |||
| # | |||
| # | |||
| # def _remove_if_exists(path): | |||
| # if os.path.exists(path): | |||
| # if os.path.isfile(path): | |||
| # os.remove(path) | |||
| # else: | |||
| # shutil.rmtree(path) | |||
| # | |||
| # | |||
| # def _git_archive_link(repo_owner, repo_name, branch): | |||
| # return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch) | |||
| # | |||
| # | |||
| # def _load_attr_from_module(module, func_name): | |||
| # # Check if callable is defined in the module | |||
| # if func_name not in dir(module): | |||
| # return None | |||
| # return getattr(module, func_name) | |||
| # | |||
| # | |||
| # def _get_torch_home(): | |||
| # torch_home = os.path.expanduser( | |||
| # os.getenv(ENV_TORCH_HOME, | |||
| # os.path.join(os.getenv(ENV_XDG_CACHE_HOME, | |||
| # DEFAULT_CACHE_DIR), 'torch'))) | |||
| # return torch_home | |||
| # | |||
| # | |||
| # def _parse_repo_info(github): | |||
| # branch = MASTER_BRANCH | |||
| # if ':' in github: | |||
| # repo_info, branch = github.split(':') | |||
| # else: | |||
| # repo_info = github | |||
| # repo_owner, repo_name = repo_info.split('/') | |||
| # return repo_owner, repo_name, branch | |||
| # | |||
| # | |||
| # def _get_cache_or_reload(github, force_reload, verbose=True): | |||
| # # Setup hub_dir to save downloaded files | |||
| # hub_dir = get_dir() | |||
| # if not os.path.exists(hub_dir): | |||
| # os.makedirs(hub_dir) | |||
| # # Parse github repo information | |||
| # repo_owner, repo_name, branch = _parse_repo_info(github) | |||
| # # Github allows branch name with slash '/', | |||
| # # this causes confusion with path on both Linux and Windows. | |||
| # # Backslash is not allowed in Github branch name so no need to | |||
| # # to worry about it. | |||
| # normalized_br = branch.replace('/', '_') | |||
| # # Github renames folder repo-v1.x.x to repo-1.x.x | |||
| # # We don't know the repo name before downloading the zip file | |||
| # # and inspect name from it. | |||
| # # To check if cached repo exists, we need to normalize folder names. | |||
| # repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br])) | |||
| # | |||
| # use_cache = (not force_reload) and os.path.exists(repo_dir) | |||
| # | |||
| # if use_cache: | |||
| # if verbose: | |||
| # sys.stderr.write('Using cache found in {}\n'.format(repo_dir)) | |||
| # else: | |||
| # cached_file = os.path.join(hub_dir, normalized_br + '.zip') | |||
| # _remove_if_exists(cached_file) | |||
| # | |||
| # url = _git_archive_link(repo_owner, repo_name, branch) | |||
| # sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file)) | |||
| # download_url_to_file(url, cached_file, progress=False) | |||
| # | |||
| # with zipfile.ZipFile(cached_file) as cached_zipfile: | |||
| # extraced_repo_name = cached_zipfile.infolist()[0].filename | |||
| # extracted_repo = os.path.join(hub_dir, extraced_repo_name) | |||
| # _remove_if_exists(extracted_repo) | |||
| # # Unzip the code and rename the base folder | |||
| # cached_zipfile.extractall(hub_dir) | |||
| # | |||
| # _remove_if_exists(cached_file) | |||
| # _remove_if_exists(repo_dir) | |||
| # shutil.move(extracted_repo, repo_dir) # rename the repo | |||
| # | |||
| # return repo_dir | |||
| # | |||
| # | |||
| # def _check_module_exists(name): | |||
| # if sys.version_info >= (3, 4): | |||
| # import importlib.util | |||
| # return importlib.util.find_spec(name) is not None | |||
| # elif sys.version_info >= (3, 3): | |||
| # # Special case for python3.3 | |||
| # import importlib.find_loader | |||
| # return importlib.find_loader(name) is not None | |||
| # else: | |||
| # # NB: Python2.7 imp.find_module() doesn't respect PEP 302, | |||
| # # it cannot find a package installed as .egg(zip) file. | |||
| # # Here we use workaround from: | |||
| # # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1 | |||
| # # Also imp doesn't handle hierarchical module names (names contains dots). | |||
| # try: | |||
| # # 1. Try imp.find_module(), which searches sys.path, but does | |||
| # # not respect PEP 302 import hooks. | |||
| # import imp | |||
| # result = imp.find_module(name) | |||
| # if result: | |||
| # return True | |||
| # except ImportError: | |||
| # pass | |||
| # path = sys.path | |||
| # for item in path: | |||
| # # 2. Scan path for import hooks. sys.path_importer_cache maps | |||
| # # path items to optional "importer" objects, that implement | |||
| # # find_module() etc. Note that path must be a subset of | |||
| # # sys.path for this to work. | |||
| # importer = sys.path_importer_cache.get(item) | |||
| # if importer: | |||
| # try: | |||
| # result = importer.find_module(name, [item]) | |||
| # if result: | |||
| # return True | |||
| # except ImportError: | |||
| # pass | |||
| # return False | |||
| # | |||
| # def _check_dependencies(m): | |||
| # dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) | |||
| # | |||
| # if dependencies is not None: | |||
| # missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] | |||
| # if len(missing_deps): | |||
| # raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps))) | |||
| # | |||
| # | |||
| # def _load_entry_from_hubconf(m, model): | |||
| # if not isinstance(model, str): | |||
| # raise ValueError('Invalid input: model should be a string of function name') | |||
| # | |||
| # # Note that if a missing dependency is imported at top level of hubconf, it will | |||
| # # throw before this function. It's a chicken and egg situation where we have to | |||
| # # load hubconf to know what're the dependencies, but to import hubconf it requires | |||
| # # a missing package. This is fine, Python will throw proper error message for users. | |||
| # _check_dependencies(m) | |||
| # | |||
| # func = _load_attr_from_module(m, model) | |||
| # | |||
| # if func is None or not callable(func): | |||
| # raise RuntimeError('Cannot find callable {} in hubconf'.format(model)) | |||
| # | |||
| # return func | |||
| # | |||
| # | |||
| # def get_dir(): | |||
| # r""" | |||
| # Get the Torch Hub cache directory used for storing downloaded models & weights. | |||
| # | |||
| # If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where | |||
| # environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. | |||
| # ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux | |||
| # filesystem layout, with a default value ``~/.cache`` if the environment | |||
| # variable is not set. | |||
| # """ | |||
| # # Issue warning to move data if old env is set | |||
| # if os.getenv('TORCH_HUB'): | |||
| # warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead') | |||
| # | |||
| # if _hub_dir is not None: | |||
| # return _hub_dir | |||
| # return os.path.join(_get_torch_home(), 'hub') | |||
| # | |||
| # | |||
| # def set_dir(d): | |||
| # r""" | |||
| # Optionally set the Torch Hub directory used to save downloaded models & weights. | |||
| # | |||
| # Args: | |||
| # d (string): path to a local folder to save downloaded models & weights. | |||
| # """ | |||
| # global _hub_dir | |||
| # _hub_dir = d | |||
| # | |||
| # | |||
| # def list(github, force_reload=False): | |||
| # r""" | |||
| # List all entrypoints available in `github` hubconf. | |||
| # | |||
| # Args: | |||
| # github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional | |||
| # tag/branch. The default branch is `master` if not specified. | |||
| # Example: 'pytorch/vision[:hub]' | |||
| # force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||
| # Default is `False`. | |||
| # Returns: | |||
| # entrypoints: a list of available entrypoint names | |||
| # | |||
| # Example: | |||
| # >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) | |||
| # """ | |||
| # repo_dir = _get_cache_or_reload(github, force_reload, True) | |||
| # | |||
| # sys.path.insert(0, repo_dir) | |||
| # | |||
| # hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||
| # | |||
| # sys.path.remove(repo_dir) | |||
| # | |||
| # # We take functions starts with '_' as internal helper functions | |||
| # entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')] | |||
| # | |||
| # return entrypoints | |||
| # | |||
| # | |||
| # def help(github, model, force_reload=False): | |||
| # r""" | |||
| # Show the docstring of entrypoint `model`. | |||
| # | |||
| # Args: | |||
| # github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional | |||
| # tag/branch. The default branch is `master` if not specified. | |||
| # Example: 'pytorch/vision[:hub]' | |||
| # model (string): a string of entrypoint name defined in repo's hubconf.py | |||
| # force_reload (bool, optional): whether to discard the existing cache and force a fresh download. | |||
| # Default is `False`. | |||
| # Example: | |||
| # >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) | |||
| # """ | |||
| # repo_dir = _get_cache_or_reload(github, force_reload, True) | |||
| # | |||
| # sys.path.insert(0, repo_dir) | |||
| # | |||
| # hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) | |||
| # | |||
| # sys.path.remove(repo_dir) | |||
| # | |||
| # entry = _load_entry_from_hubconf(hub_module, model) | |||
| # | |||
| # return entry.__doc__ | |||
| # | |||
| # | |||
| # # Ideally this should be `def load(github, model, *args, forece_reload=False, **kwargs):`, | |||
| # # but Python2 complains syntax error for it. We have to skip force_reload in function | |||
| # # signature here but detect it in kwargs instead. | |||
| # # TODO: fix it after Python2 EOL | |||
| # def load(repo_or_dir, model, *args, **kwargs): | |||
| # r""" | |||
| # Load a model from a github repo or a local directory. | |||
| # | |||
| # Note: Loading a model is the typical use case, but this can also be used to | |||
| # for loading other objects such as tokenizers, loss functions, etc. | |||
| # | |||
| # If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be | |||
| # of the form ``repo_owner/repo_name[:tag_name]`` with an optional | |||
| # tag/branch. | |||
| # | |||
| # If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a | |||
| # path to a local directory. | |||
| # | |||
| # Args: | |||
| # repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``), | |||
| # if ``source = 'github'``; or a path to a local directory, if | |||
| # ``source = 'local'``. | |||
| # model (string): the name of a callable (entrypoint) defined in the | |||
| # repo/dir's ``hubconf.py``. | |||
| # *args (optional): the corresponding args for callable :attr:`model`. | |||
| # source (string, optional): ``'github'`` | ``'local'``. Specifies how | |||
| # ``repo_or_dir`` is to be interpreted. Default is ``'github'``. | |||
| # force_reload (bool, optional): whether to force a fresh download of | |||
| # the github repo unconditionally. Does not have any effect if | |||
| # ``source = 'local'``. Default is ``False``. | |||
| # verbose (bool, optional): If ``False``, mute messages about hitting | |||
| # local caches. Note that the message about first download cannot be | |||
| # muted. Does not have any effect if ``source = 'local'``. | |||
| # Default is ``True``. | |||
| # **kwargs (optional): the corresponding kwargs for callable | |||
| # :attr:`model`. | |||
| # | |||
| # Returns: | |||
| # The output of the :attr:`model` callable when called with the given | |||
| # ``*args`` and ``**kwargs``. | |||
| # | |||
| # Example: | |||
| # >>> # from a github repo | |||
| # >>> repo = 'pytorch/vision' | |||
| # >>> model = torch.hub.load(repo, 'resnet50', pretrained=True) | |||
| # >>> # from a local directory | |||
| # >>> path = '/some/local/path/pytorch/vision' | |||
| # >>> model = torch.hub.load(path, 'resnet50', pretrained=True) | |||
| # """ | |||
| # source = kwargs.pop('source', 'github').lower() | |||
| # force_reload = kwargs.pop('force_reload', False) | |||
| # verbose = kwargs.pop('verbose', True) | |||
| # | |||
| # if source not in ('github', 'local'): | |||
| # raise ValueError( | |||
| # f'Unknown source: "{source}". Allowed values: "github" | "local".') | |||
| # | |||
| # if source == 'github': | |||
| # repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose) | |||
| # | |||
| # model = _load_local(repo_or_dir, model, *args, **kwargs) | |||
| # return model | |||
| # | |||
| # | |||
| # def _load_local(hubconf_dir, model, *args, **kwargs): | |||
| # r""" | |||
| # Load a model from a local directory with a ``hubconf.py``. | |||
| # | |||
| # Args: | |||
| # hubconf_dir (string): path to a local directory that contains a | |||
| # ``hubconf.py``. | |||
| # model (string): name of an entrypoint defined in the directory's | |||
| # `hubconf.py`. | |||
| # *args (optional): the corresponding args for callable ``model``. | |||
| # **kwargs (optional): the corresponding kwargs for callable ``model``. | |||
| # | |||
| # Returns: | |||
| # a single model with corresponding pretrained weights. | |||
| # | |||
| # Example: | |||
| # >>> path = '/some/local/path/pytorch/vision' | |||
| # >>> model = _load_local(path, 'resnet50', pretrained=True) | |||
| # """ | |||
| # sys.path.insert(0, hubconf_dir) | |||
| # | |||
| # hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) | |||
| # hub_module = import_module(MODULE_HUBCONF, hubconf_path) | |||
| # | |||
| # entry = _load_entry_from_hubconf(hub_module, model) | |||
| # model = entry(*args, **kwargs) | |||
| # | |||
| # sys.path.remove(hubconf_dir) | |||
| # | |||
| # return model | |||
| # | |||
| # | |||
| # def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
| # r"""Download object at the given URL to a local path. | |||
| # | |||
| # Args: | |||
| # url (string): URL of the object to download | |||
| # dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` | |||
| # hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. | |||
| # Default: None | |||
| # progress (bool, optional): whether or not to display a progress bar to stderr | |||
| # Default: True | |||
| # | |||
| # Example: | |||
| # >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') | |||
| # | |||
| # """ | |||
| # file_size = None | |||
| # # We use a different API for python2 since urllib(2) doesn't recognize the CA | |||
| # # certificates in older Python | |||
| # req = Request(url, headers={"User-Agent": "torch.hub"}) | |||
| # u = urlopen(req) | |||
| # meta = u.info() | |||
| # if hasattr(meta, 'getheaders'): | |||
| # content_length = meta.getheaders("Content-Length") | |||
| # else: | |||
| # content_length = meta.get_all("Content-Length") | |||
| # if content_length is not None and len(content_length) > 0: | |||
| # file_size = int(content_length[0]) | |||
| # | |||
| # # We deliberately save it in a temp file and move it after | |||
| # # download is complete. This prevents a local working checkpoint | |||
| # # being overridden by a broken download. | |||
| # dst = os.path.expanduser(dst) | |||
| # dst_dir = os.path.dirname(dst) | |||
| # f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | |||
| # | |||
| # try: | |||
| # if hash_prefix is not None: | |||
| # sha256 = hashlib.sha256() | |||
| # with tqdm(total=file_size, disable=not progress, | |||
| # unit='B', unit_scale=True, unit_divisor=1024) as pbar: | |||
| # while True: | |||
| # buffer = u.read(8192) | |||
| # if len(buffer) == 0: | |||
| # break | |||
| # f.write(buffer) | |||
| # if hash_prefix is not None: | |||
| # sha256.update(buffer) | |||
| # pbar.update(len(buffer)) | |||
| # | |||
| # f.close() | |||
| # if hash_prefix is not None: | |||
| # digest = sha256.hexdigest() | |||
| # if digest[:len(hash_prefix)] != hash_prefix: | |||
| # raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||
| # .format(hash_prefix, digest)) | |||
| # shutil.move(f.name, dst) | |||
| # finally: | |||
| # f.close() | |||
| # if os.path.exists(f.name): | |||
| # os.remove(f.name) | |||
| # | |||
| # def _download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
| # warnings.warn('torch.hub._download_url_to_file has been renamed to\ | |||
| # torch.hub.download_url_to_file to be a public API,\ | |||
| # _download_url_to_file will be removed in after 1.3 release') | |||
| # download_url_to_file(url, dst, hash_prefix, progress) | |||
| # | |||
| # # Hub used to support automatically extracts from zipfile manually compressed by users. | |||
| # # The legacy zip format expects only one file from torch.save() < 1.6 in the zip. | |||
| # # We should remove this support since zipfile is now default zipfile format for torch.save(). | |||
| # def _is_legacy_zip_format(filename): | |||
| # if zipfile.is_zipfile(filename): | |||
| # infolist = zipfile.ZipFile(filename).infolist() | |||
| # return len(infolist) == 1 and not infolist[0].is_dir() | |||
| # return False | |||
| # | |||
| # def _legacy_zip_load(filename, model_dir, map_location): | |||
| # warnings.warn('Falling back to the old format < 1.6. This support will be ' | |||
| # 'deprecated in favor of default zipfile format introduced in 1.6. ' | |||
| # 'Please redo torch.save() to save it in the new zipfile format.') | |||
| # # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. | |||
| # # We deliberately don't handle tarfile here since our legacy serialization format was in tar. | |||
| # # E.g. resnet18-5c106cde.pth which is widely used. | |||
| # with zipfile.ZipFile(filename) as f: | |||
| # members = f.infolist() | |||
| # if len(members) != 1: | |||
| # raise RuntimeError('Only one file(not dir) is allowed in the zipfile') | |||
| # f.extractall(model_dir) | |||
| # extraced_name = members[0].filename | |||
| # extracted_file = os.path.join(model_dir, extraced_name) | |||
| # return torch.load(extracted_file, map_location=map_location) | |||
| # | |||
| # def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): | |||
| # r"""Loads the Torch serialized object at the given URL. | |||
| # | |||
| # If downloaded file is a zip file, it will be automatically | |||
| # decompressed. | |||
| # | |||
| # If the object is already present in `model_dir`, it's deserialized and | |||
| # returned. | |||
| # The default value of `model_dir` is ``<hub_dir>/checkpoints`` where | |||
| # `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. | |||
| # | |||
| # Args: | |||
| # url (string): URL of the object to download | |||
| # model_dir (string, optional): directory in which to save the object | |||
| # map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||
| # progress (bool, optional): whether or not to display a progress bar to stderr. | |||
| # Default: True | |||
| # check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention | |||
| # ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||
| # digits of the SHA256 hash of the contents of the file. The hash is used to | |||
| # ensure unique names and to verify the contents of the file. | |||
| # Default: False | |||
| # file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. | |||
| # | |||
| # Example: | |||
| # >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||
| # | |||
| # """ | |||
| # # Issue warning to move data if old env is set | |||
| # if os.getenv('TORCH_MODEL_ZOO'): | |||
| # warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') | |||
| # | |||
| # if model_dir is None: | |||
| # hub_dir = get_dir() | |||
| # model_dir = os.path.join(hub_dir, 'checkpoints') | |||
| # | |||
| # try: | |||
| # os.makedirs(model_dir) | |||
| # except OSError as e: | |||
| # if e.errno == errno.EEXIST: | |||
| # # Directory already exists, ignore. | |||
| # pass | |||
| # else: | |||
| # # Unexpected OSError, re-raise. | |||
| # raise | |||
| # | |||
| # parts = urlparse(url) | |||
| # filename = os.path.basename(parts.path) | |||
| # if file_name is not None: | |||
| # filename = file_name | |||
| # cached_file = os.path.join(model_dir, filename) | |||
| # if not os.path.exists(cached_file): | |||
| # sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||
| # hash_prefix = None | |||
| # if check_hash: | |||
| # r = HASH_REGEX.search(filename) # r is Optional[Match[str]] | |||
| # hash_prefix = r.group(1) if r else None | |||
| # download_url_to_file(url, cached_file, hash_prefix, progress=progress) | |||
| # | |||
| # if _is_legacy_zip_format(cached_file): | |||
| # return _legacy_zip_load(cached_file, model_dir, map_location) | |||
| # return torch.load(cached_file, map_location=map_location) | |||
| @@ -0,0 +1,139 @@ | |||
| import warnings | |||
| import os | |||
| import os.path | |||
| import numpy as np | |||
| import codecs | |||
| import string | |||
| import gzip | |||
| import lzma | |||
| from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union | |||
| from dataset_utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg | |||
| # tqdm >= 4.31.1 | |||
| from tods import generate_dataset | |||
| from sklearn import preprocessing | |||
| import pandas as pd | |||
class TODS_dataset:
    """Base class for downloadable TODS anomaly-detection datasets.

    Subclasses declare ``resources`` ((url, md5) pairs), the processed file
    names and ``ground_truth_index``, and implement :meth:`process` to turn
    the raw download into ``training_set_dataframe`` /
    ``testing_set_dataframe`` plus processed files on disk.
    """

    # (url, md5) pairs downloaded into ``raw_folder``; md5 may be None.
    resources = []
    # File names written under ``processed_folder``/tables by process().
    training_file = None
    testing_file = None
    # Column index of the anomaly label in the processed table.
    ground_truth_index = None
    # Indentation width used by __repr__.
    _repr_indent = None

    @property
    def raw_folder(self) -> str:
        """Directory holding the downloaded, unprocessed files."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        """Directory holding the processed output files."""
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    def __init__(self, root, train, transform=None, download=True):
        """
        Args:
            root: Root directory under which raw/processed data is stored.
            train: If True, to_axolotl_dataset() serves the training split.
            transform: One of 'standardscale', 'normalize', 'minmaxscale',
                'maxabsscale', 'binarize', or None for no preprocessing.
            download: If True, fetch resources when not already present.
        """
        self.root = root
        self.train = train
        self.transform = self.transform_init(transform)

        if download:
            self.download()

        self.process()

    def _check_exists(self) -> bool:
        """Return True if both processed files are already on disk."""
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.testing_file)))

    def download(self) -> None:
        """Download (and extract) every resource into ``raw_folder``.

        Skipped entirely when the processed files already exist.
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # Download files; the file name is taken from the last URL segment.
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

    def process(self) -> None:
        """Turn raw files into processed dataframes/files; subclass hook."""
        pass

    def process_dataframe(self) -> None:
        """Apply the configured sklearn transform to both splits in place."""
        if self.transform is not None:
            # Fit on the training values and apply the same fitted transform
            # to both splits. Fitting on ``.values`` keeps the fit/transform
            # inputs consistent (the original fit on the DataFrame but
            # transformed the raw ndarray).
            self.transform.fit(self.training_set_dataframe.values)
            self.training_set_array = self.transform.transform(self.training_set_dataframe.values)
            self.testing_set_array = self.transform.transform(self.testing_set_dataframe.values)
            self.training_set_dataframe = pd.DataFrame(self.training_set_array)
            self.testing_set_dataframe = pd.DataFrame(self.testing_set_array)

    def transform_init(self, transform_str):
        """Map a transform name to an sklearn preprocessing object.

        Returns None when ``transform_str`` is None; raises ValueError for
        any unrecognized name.
        """
        if transform_str is None:
            return None
        elif transform_str == 'standardscale':
            return preprocessing.StandardScaler()
        elif transform_str == 'normalize':
            return preprocessing.Normalizer()
        elif transform_str == 'minmaxscale':
            return preprocessing.MinMaxScaler()
        elif transform_str == 'maxabsscale':
            return preprocessing.MaxAbsScaler()
        elif transform_str == 'binarize':
            return preprocessing.Binarizer()
        else:
            raise ValueError("Input parameter transform must take value of 'standardscale', 'normalize', " +
                             "'minmaxscale', 'maxabsscale' or 'binarize'."
                             )

    def to_axolotl_dataset(self):
        """Wrap the selected split as an axolotl/d3m dataset."""
        if self.train:
            return generate_dataset(self.training_set_dataframe, self.ground_truth_index)
        else:
            return generate_dataset(self.testing_set_dataframe, self.ground_truth_index)

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        # Fixed: the attribute is ``transform`` — the old check looked for a
        # non-existent ``transforms`` and was dead code. Also removed the
        # debugging print() of the whole training dataframe (a repr must not
        # have side effects).
        if getattr(self, "transform", None) is not None:
            body += [repr(self.transform)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def __len__(self) -> int:
        # NOTE(review): length is always that of the training split, even
        # when ``train`` is False — confirm this is intended.
        return len(self.training_set_dataframe)

    def extra_repr(self) -> str:
        """Extra lines appended to __repr__; subclass hook."""
        return ""
| # kpi(root='./datasets', train=True) | |||
| # class yahoo5: | |||
| # | |||
| # def __init__(self): | |||
| # pass | |||
| @@ -0,0 +1,116 @@ | |||
| import os | |||
| import pandas as pd | |||
| from tods_dataset_base import TODS_dataset | |||
| from shutil import copyfile | |||
class kpi_dataset(TODS_dataset):
    """KPI anomaly-detection dataset hosted for TODS."""

    resources = [
        # An md5 checksum would let us verify a local learningData.csv against the hosted copy.
        ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
        ("https://hegsns.github.io/tods_datasets/kpi/TRAIN/dataset_TRAIN/datasetDoc.json", None),
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 3
    _repr_indent = 4

    def process(self) -> None:
        """Load the raw CSV, apply the optional transform, and write the processed files."""
        print('Processing...')

        tables_dir = os.path.join(self.processed_folder, 'tables')
        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(tables_dir, exist_ok=True)

        raw_csv = os.path.join(self.raw_folder, 'learningData.csv')
        self.training_set_dataframe = pd.read_csv(raw_csv)
        # The testing split temporarily mirrors the training split.
        self.testing_set_dataframe = pd.read_csv(raw_csv)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(tables_dir, self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(tables_dir, self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')
class yahoo_dataset(TODS_dataset):
    """Yahoo (yahoo_sub_5) anomaly-detection dataset hosted for TODS."""

    resources = [
        # An md5 checksum would let us verify a local learningData.csv against the hosted copy.
        ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/tables/learningData.csv", None),
        ("https://hegsns.github.io/tods_datasets/yahoo_sub_5/TRAIN/dataset_TRAIN/datasetDoc.json", None),
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 7
    _repr_indent = 4

    def process(self) -> None:
        """Load the raw CSV, apply the optional transform, and write the processed files."""
        print('Processing...')

        tables_dir = os.path.join(self.processed_folder, 'tables')
        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(tables_dir, exist_ok=True)

        raw_csv = os.path.join(self.raw_folder, 'learningData.csv')
        self.training_set_dataframe = pd.read_csv(raw_csv)
        # The testing split temporarily mirrors the training split.
        self.testing_set_dataframe = pd.read_csv(raw_csv)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(tables_dir, self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(tables_dir, self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'datasetDoc.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')
class NAB_dataset(TODS_dataset):
    """NAB realTweets (Twitter AMZN volume) anomaly-detection dataset hosted for TODS."""

    resources = [
        # An md5 checksum would let us verify the local CSV against the hosted copy.
        ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.csv", None),
        ("https://hegsns.github.io/tods_datasets/NAB/realTweets/labeled_Twitter_volume_AMZN.json", None),
    ]

    training_file = 'learningData.csv'
    testing_file = 'testingData.csv'
    ground_truth_index = 2
    _repr_indent = 4

    def process(self) -> None:
        """Load the raw CSV, apply the optional transform, and write the processed files."""
        print('Processing...')

        tables_dir = os.path.join(self.processed_folder, 'tables')
        os.makedirs(self.processed_folder, exist_ok=True)
        os.makedirs(tables_dir, exist_ok=True)

        raw_csv = os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.csv')
        self.training_set_dataframe = pd.read_csv(raw_csv)
        # The testing split temporarily mirrors the training split.
        self.testing_set_dataframe = pd.read_csv(raw_csv)

        self.process_dataframe()

        self.training_set_dataframe.to_csv(os.path.join(tables_dir, self.training_file))
        self.testing_set_dataframe.to_csv(os.path.join(tables_dir, self.testing_file))
        copyfile(os.path.join(self.raw_folder, 'labeled_Twitter_volume_AMZN.json'),
                 os.path.join(self.processed_folder, 'datasetDoc.json'))

        print('Done!')
| # kpi_dataset(root='./datasets', train=True, transform='binarize') | |||
| # yahoo_dataset(root='./datasets', train=True, transform='binarize') | |||
| # NAB_dataset(root='./datasets', train=True, transform='binarize') | |||
| @@ -1 +0,0 @@ | |||
| Subproject commit af54e6970476a081bf0cd65990c9f56a1200d8a2 | |||
| @@ -1 +0,0 @@ | |||
| Subproject commit 046b20d2f6d4543dcbe18f0a1d4bcbb1f61cf518 | |||
| @@ -1 +0,0 @@ | |||
| Subproject commit 70aeefed6b7307941581357c4b7858bb3f88e1da | |||
| @@ -0,0 +1,116 @@ | |||
| import os | |||
| import typing | |||
| import numpy | |||
| import pandas | |||
| from d3m import container, exceptions, utils as d3m_utils | |||
| from d3m.metadata import base as metadata_base, hyperparams | |||
| from d3m.base import primitives | |||
| __all__ = ('FixedSplitDatasetSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters selecting which rows form the fixed test (score) split."""

    # Primary-key values (as strings) of main-resource rows that go to the test split.
    primary_index_values = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description='A set of primary index values of the main resource belonging to the test (score) split. Cannot be set together with "row_indices".',
    )
    # Positional row indices of main-resource rows that go to the test split.
    row_indices = hyperparams.Set(
        elements=hyperparams.Hyperparameter[int](-1),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description='A set of row indices of the main resource belonging to the test (score) split. Cannot be set together with "primary_index_values".',
    )
    # Whether rows in other resources that reference removed entry-point rows are dropped too.
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )
class FixedSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset in a way that uses for the test
    (score) split a fixed list of primary index values or row indices of the main
    resource to be used. All other rows are used for the train split.
    """

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '1654f000-2178-4520-be4c-a95bc26b8d3a',
            'version': '0.1.0',
            'name': "Fixed split tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.fixed_split_dataset_split',
            'source': {
                'name': "DATALab@TexasA&M University",
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/fixed_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return a single (train, test) pair of positional row-index arrays.

        Test rows are selected by either "primary_index_values" or
        "row_indices" (mutually exclusive hyper-parameters); every remaining
        row of the main resource forms the train split.
        """

        # This should be handled by "Set" hyper-parameter, but we check it here again just to be sure.
        if d3m_utils.has_duplicates(self.hyperparams['primary_index_values']):
            raise exceptions.InvalidArgumentValueError("\"primary_index_values\" hyper-parameter has duplicate values.")
        if d3m_utils.has_duplicates(self.hyperparams['row_indices']):
            raise exceptions.InvalidArgumentValueError("\"row_indices\" hyper-parameter has duplicate values.")

        # The two selection modes are mutually exclusive.
        if self.hyperparams['primary_index_values'] and self.hyperparams['row_indices']:
            raise exceptions.InvalidArgumentValueError("Both \"primary_index_values\" and \"row_indices\" cannot be provided.")

        if self.hyperparams['primary_index_values']:
            # Translate primary-key values into positional row indices of the main resource.
            primary_index_values = numpy.array(self.hyperparams['primary_index_values'])

            index_columns = dataset.metadata.get_index_columns(at=(main_resource_id,))

            if not index_columns:
                raise exceptions.InvalidArgumentValueError("Cannot find index columns in the main resource of the dataset, but \"primary_index_values\" is provided.")

            main_resource = dataset[main_resource_id]
            # We reset the index so that the index corresponds to row indices.
            main_resource = main_resource.reset_index(drop=True)

            # We use just the "d3mIndex" column and ignore multi-key indices.
            # This works for now because it seems that every current multi-key
            # dataset in fact has an unique value in "d3mIndex" alone.
            # See: https://gitlab.datadrivendiscovery.org/MIT-LL/d3m_data_supply/issues/117
            index_column = index_columns[0]

            # Positional indices (into the reset index) of rows whose key value
            # is among "primary_index_values".
            score_data = numpy.array(main_resource.loc[main_resource.iloc[:, index_column].isin(primary_index_values)].index)
            score_data_set = set(score_data)

            # Internal invariant: a freshly reset integer index cannot contain duplicates.
            assert len(score_data) == len(score_data_set), (len(score_data), len(score_data_set))

            # Fewer matches than requested values means some values were not found.
            if len(score_data) != len(primary_index_values):
                raise exceptions.InvalidArgumentValueError("\"primary_index_values\" contains values which do not exist.")
        else:
            score_data = numpy.array(self.hyperparams['row_indices'])
            score_data_set = set(score_data)

            all_data_set = set(numpy.arange(len(attributes)))

            # Reject any requested index outside the valid row range.
            if not score_data_set <= all_data_set:
                raise exceptions.InvalidArgumentValueError("\"row_indices\" contains indices which do not exist, e.g., {indices}.".format(
                    indices=sorted(score_data_set - all_data_set)[:5],
                ))

        # Everything not selected for the test split becomes the train split.
        train_data = []
        for i in numpy.arange(len(attributes)):
            if i not in score_data_set:
                train_data.append(i)

        # The two splits must exactly partition the rows.
        assert len(train_data) + len(score_data) == len(attributes), (len(train_data), len(score_data), len(attributes))

        return [(numpy.array(train_data), score_data)]
| @@ -0,0 +1,87 @@ | |||
| import os | |||
| import typing | |||
| import numpy | |||
| import pandas | |||
| from sklearn import model_selection | |||
| from d3m import container, exceptions, utils as d3m_utils | |||
| from d3m.metadata import base as metadata_base, hyperparams | |||
| from d3m.base import primitives | |||
| __all__ = ('KFoldDatasetSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters controlling the k-fold split."""

    # Number of folds; the lower bound of 2 matches sklearn's KFold requirement.
    number_of_folds = hyperparams.Bounded[int](
        lower=2,
        upper=None,
        default=5,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Number of folds for k-folds cross-validation.",
    )
    # When True, StratifiedKFold is used instead of plain KFold (requires target columns).
    stratified = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
    )
    # Passed straight through to the sklearn splitter.
    shuffle = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Whether to shuffle the data before splitting into batches.",
    )
    # Whether rows in other resources that reference removed entry-point rows are dropped too.
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )
class KFoldDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset for k-fold cross-validation.
    """

    __author__ = 'Mingjie Sun <sunmj15@gmail.com>'

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': 'bfedaf3a-6dd0-4a83-ad83-3a50fe882bf8',
            'version': '0.1.0',
            'name': "K-fold cross-validation tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.kfold_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:sunmj15@gmail.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.K_FOLD,
                metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return the list of (train, test) row-index arrays, one pair per fold.

        Uses sklearn's StratifiedKFold when the "stratified" hyper-parameter
        is set (which requires at least one target column), plain KFold
        otherwise.
        """
        # Both sklearn splitters share the same constructor arguments.
        splitter_kwargs = dict(
            n_splits=self.hyperparams['number_of_folds'],
            shuffle=self.hyperparams['shuffle'],
            random_state=self._random_state,
        )

        if self.hyperparams['stratified']:
            # Stratification needs class labels to preserve per-class ratios.
            if not len(targets.columns):
                raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")

            splitter = model_selection.StratifiedKFold(**splitter_kwargs)
        else:
            splitter = model_selection.KFold(**splitter_kwargs)

        return list(splitter.split(attributes, targets))
| @@ -0,0 +1,187 @@ | |||
| import os | |||
| import uuid | |||
| import typing | |||
| from collections import OrderedDict | |||
| import numpy | |||
| import pandas | |||
| from sklearn import model_selection | |||
| from d3m import container, exceptions, utils as d3m_utils | |||
| from d3m.metadata import base as metadata_base, hyperparams | |||
| from d3m.base import primitives | |||
| import utils | |||
| __all__ = ('KFoldTimeSeriesSplitPrimitive',) | |||
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters controlling the time-series k-fold split."""

    # Number of folds; the lower bound of 2 matches sklearn's TimeSeriesSplit requirement.
    number_of_folds = hyperparams.Bounded[int](
        lower=2,
        upper=None,
        default=5,
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Number of folds for k-folds cross-validation.",
    )
    # Union: a fixed fold count caps the size of each training window;
    # None means no cap (use every available earlier record).
    number_of_window_folds = hyperparams.Union[typing.Union[int, None]](
        configuration=OrderedDict(
            fixed=hyperparams.Bounded[int](
                lower=1,
                upper=None,
                default=1,
                description="Number of folds in train set (window). These folds come directly "
                            "before test set (streaming window).",
            ),
            all_records=hyperparams.Constant(
                default=None,
                description="Number of folds in train set (window) = maximum number possible.",
            ),
        ),
        default='all_records',
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Maximum size for a single training set.",
    )
    # Union: an explicit dataset column index to sort by, or None to
    # auto-detect the single column with the Time semantic type.
    time_column_index = hyperparams.Union[typing.Union[int, None]](
        configuration=OrderedDict(
            fixed=hyperparams.Bounded[int](
                lower=1,
                upper=None,
                default=1,
                description="Specific column that contains the time index",
            ),
            one_column=hyperparams.Constant(
                default=None,
                description="Only one column contains a time index. "
                            "It is detected automatically using semantic types.",
            ),
        ),
        default='one_column',
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Column index to use as datetime index. "
                    "If None, it is required that only one column with time column role semantic type is "
                    "present and otherwise an exception is raised. "
                    "If column index specified is not a datetime column an exception is"
                    "also raised.",
    )
    # Forwarded to the datetime parser when converting the time column.
    fuzzy_time_parsing = hyperparams.UniformBool(
        default=True,
        semantic_types=[
            'https://metadata.datadrivendiscovery.org/types/ControlParameter'
        ],
        description="Use fuzzy time parsing.",
    )
class KFoldTimeSeriesSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular time-series Dataset for k-fold cross-validation.

    Primitive sorts the time column so care should be taken to assure sorting of a
    column is reasonable. E.g., if column is not numeric but of string structural type,
    strings should be formatted so that sorting by them also sorts by time.
    """

    __author__ = 'Distil'
    __version__ = '0.3.0'
    __contact__ = 'mailto:jeffrey.gleason@yonder.co'

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '002f9ad1-46e3-40f4-89ed-eeffbb3a102b',
            'version': __version__,
            'name': "K-fold cross-validation timeseries dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.kfold_time_series_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': __contact__,
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/kfold_split_timeseries.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.K_FOLD,
                metadata_base.PrimitiveAlgorithmType.CROSS_VALIDATION,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return (train, test) row-index pairs from a time-ordered split.

        Resolves the time column, sorts rows by parsed time, runs sklearn's
        TimeSeriesSplit on the sorted frame, and maps the resulting indices
        back to the original (unsorted) row order.
        """
        # Dataset-level indices of columns with the Time semantic type.
        time_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Time'], at=(main_resource_id,))
        attribute_column_indices = dataset.metadata.list_columns_with_semantic_types(['https://metadata.datadrivendiscovery.org/types/Attribute'], at=(main_resource_id,))

        # We want only time columns which are also attributes.
        time_column_indices = [time_column_index for time_column_index in time_column_indices if time_column_index in attribute_column_indices]

        if self.hyperparams['time_column_index'] is None:
            # Auto-detection requires exactly one candidate time column.
            if len(time_column_indices) != 1:
                raise exceptions.InvalidArgumentValueError(
                    "If \"time_column_index\" hyper-parameter is \"None\", it is required that exactly one column with time column role semantic type is present.",
                )
            else:
                # We know it exists because "time_column_indices" is a subset of "attribute_column_indices".
                # ".index(...)" converts a dataset-level column index into a
                # position within "attributes" (which holds attribute columns only).
                time_column_index = attribute_column_indices.index(
                    time_column_indices[0],
                )
        else:
            # The hyper-parameter is a dataset-level column index; it must be
            # one of the detected time columns.
            if self.hyperparams['time_column_index'] not in time_column_indices:
                raise exceptions.InvalidArgumentValueError(
                    "Time column index specified does not have a time column role semantic type.",
                )
            else:
                time_column_index = attribute_column_indices.index(
                    self.hyperparams['time_column_index'],
                )

        # We first reset index.
        attributes = attributes.reset_index(drop=True)

        # Then convert datetime column to consistent datetime representation
        # (a parsed-float copy inserted as the first column, used only for sorting).
        attributes.insert(
            loc=0,
            column=uuid.uuid4(),  # use uuid to ensure we are inserting a new column name
            value=self._parse_time_data(
                attributes, time_column_index, self.hyperparams['fuzzy_time_parsing'],
            ),
        )

        # Then sort dataframe by new datetime column. Index contains original row order.
        attributes = attributes.sort_values(by=attributes.columns[0])

        # Remove datetime representation used for sorting (primitives might choose to parse this str col differently).
        attributes = attributes.drop(attributes.columns[0], axis=1)

        max_train_size: typing.Optional[int] = None
        if self.hyperparams['number_of_window_folds'] is not None:
            # Cap each training window at roughly "number_of_window_folds"
            # folds' worth of rows.
            max_train_size = int(attributes.shape[0] * self.hyperparams['number_of_window_folds'] / self.hyperparams['number_of_folds'])

        k_fold = model_selection.TimeSeriesSplit(
            n_splits=self.hyperparams['number_of_folds'],
            max_train_size=max_train_size
        )

        # We sorted "attributes" so we have to map indices on sorted "attributes" back to original
        # indices. We do that by using DataFrame's index which contains original row order.
        return [
            (
                numpy.array([attributes.index[val] for val in train]),
                numpy.array([attributes.index[val] for val in test]),
            )
            for train, test in k_fold.split(attributes)
        ]

    @classmethod
    def _parse_time_data(cls, inputs: container.DataFrame, column_index: metadata_base.SimpleSelectorSegment, fuzzy: bool) -> typing.List[float]:
        """Parse every value of the given column into a sortable float timestamp."""
        return [
            utils.parse_datetime_to_float(value, fuzzy=fuzzy)
            for value in inputs.iloc[:, column_index]
        ]
| @@ -0,0 +1,52 @@ | |||
| import os | |||
| import typing | |||
| import numpy | |||
| import pandas | |||
| from d3m import container, utils as d3m_utils | |||
| from d3m.metadata import base as metadata_base, hyperparams | |||
| from d3m.base import primitives | |||
# Public API of this module.
__all__ = ('NoSplitDatasetSplitPrimitive',)
class Hyperparams(hyperparams.Hyperparams):
    # This primitive exposes no tunable hyper-parameters.
    pass
class NoSplitDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset in a way that for all splits it
    produces the same (full) Dataset. Useful for unsupervised learning tasks.
    """

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '48c683ad-da9e-48cf-b3a0-7394dba5e5d2',
            'version': '0.1.0',
            'name': "No-split tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.no_split_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/no_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.IDENTITY_FUNCTION,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return a single split where train and score both contain every row."""
        # The base class still runs the regular splitting machinery (and its
        # error conditions); this subclass simply reuses all rows for both sides.
        row_indices = numpy.arange(attributes.shape[0])
        return [(row_indices, row_indices)]
| @@ -0,0 +1,160 @@ | |||
| import copy | |||
| import os | |||
| import typing | |||
| from d3m import container, exceptions, utils as d3m_utils | |||
| from d3m.metadata import base as metadata_base, hyperparams | |||
| from d3m.primitive_interfaces import base, transformer | |||
# Type aliases for this primitive's interface: a list of Dataset objects in,
# a list of (redacted) Dataset objects out.
Inputs = container.List
Outputs = container.List
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters controlling which columns are redacted and how."""

    # Whether a column must carry all of "semantic_types", or just any one of them.
    match_logic = hyperparams.Enumeration(
        values=['all', 'any'],
        default='any',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Should a column have all of semantic types in \"semantic_types\" to be redacted, or any of them?",
    )
    # Semantic types used to select the columns to redact.
    semantic_types = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Redact columns with these semantic types. Only columns having semantic types listed here will be operated on, based on \"match_logic\".",
    )
    # Semantic types added (in metadata) to every column that was redacted.
    add_semantic_types = hyperparams.Set(
        elements=hyperparams.Hyperparameter[str](''),
        default=(),
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Semantic types to add to redacted columns. All listed semantic types will be added to all columns which were redacted.",
    )
# TODO: Make clear the assumption that both the container type (List) and the Datasets should have metadata.
#       The primitive modifies metadata of Datasets, while there is officially no reason for them
#       to really have metadata: metadata is stored on the input container type, not
#       on the values inside it.
class RedactColumnsPrimitive(transformer.TransformerPrimitiveBase[Inputs, Outputs, Hyperparams]):
    """
    A primitive which takes as an input a list of ``Dataset`` objects and redacts values of all columns matching
    a given semantic type or types.
    Redaction is done by setting all values in a redacted column to an empty string.
    It operates only on DataFrame resources inside datasets.
    """

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '744c4090-e2f6-489e-8efc-8b1e051bfad6',
            'version': '0.2.0',
            'name': "Redact columns for evaluation",
            'python_path': 'd3m.primitives.tods.evaluation.redact_columns',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:sunmj15@gmail.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/redact_columns.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'installation': [{
                'type': metadata_base.PrimitiveInstallationType.PIP,
                'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format(
                    git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)),
                ),
            }],
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.DATA_CONVERSION,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> base.CallResult[Outputs]:
        """
        Redact matching columns in every Dataset of ``inputs``.

        For each DataFrame resource, columns selected by ``semantic_types`` /
        ``match_logic`` are overwritten with empty strings and any
        ``add_semantic_types`` are recorded in the dataset's metadata.
        Non-DataFrame resources are passed through unchanged.

        Raises ``TypeError`` for a matched column whose structural type is not ``str``.
        """
        output_datasets = container.List(generate_metadata=True)

        for dataset in inputs:
            resources = {}
            # "metadata" is threaded through the loops below: each redacted
            # column may add semantic types to it via _update_metadata.
            metadata = dataset.metadata

            for resource_id, resource in dataset.items():
                # Only DataFrame resources are redacted; everything else passes through.
                if not isinstance(resource, container.DataFrame):
                    resources[resource_id] = resource
                    continue

                columns_to_redact = self._get_columns_to_redact(metadata, (resource_id,))

                if not columns_to_redact:
                    resources[resource_id] = resource
                    continue

                # NOTE(review): copy.copy is a shallow copy; the iloc assignment
                # below presumably leaves the original dataset's data untouched,
                # but this depends on pandas copy semantics — confirm.
                resource = copy.copy(resource)

                for column_index in columns_to_redact:
                    column_metadata = dataset.metadata.query((resource_id, metadata_base.ALL_ELEMENTS, column_index))
                    # Only string-typed columns can be redacted (blanked out).
                    if 'structural_type' in column_metadata and issubclass(column_metadata['structural_type'], str):
                        resource.iloc[:, column_index] = ''
                    else:
                        raise TypeError("Primitive can operate only on columns with structural type \"str\", not \"{type}\".".format(
                            type=column_metadata.get('structural_type', None),
                        ))

                    metadata = self._update_metadata(metadata, resource_id, column_index, ())

                resources[resource_id] = resource

            # Rebuild the dataset with redacted resources and updated metadata.
            dataset = container.Dataset(resources, metadata)

            output_datasets.append(dataset)

        output_datasets.metadata = metadata_base.DataMetadata({
            'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
            'structural_type': container.List,
            'dimension': {
                'length': len(output_datasets),
            },
        })

        # We update metadata based on metadata of each dataset.
        # TODO: In the future this might be done automatically by generate_metadata.
        #       See: https://gitlab.com/datadrivendiscovery/d3m/issues/119
        for index, dataset in enumerate(output_datasets):
            output_datasets.metadata = dataset.metadata.copy_to(output_datasets.metadata, (), (index,))

        return base.CallResult(output_datasets)

    def _get_columns_to_redact(self, inputs_metadata: metadata_base.DataMetadata, at: metadata_base.Selector) -> typing.Sequence[int]:
        """
        Return indices of columns under selector ``at`` whose semantic types
        match the "semantic_types" hyper-parameter per "match_logic".
        """
        columns = []

        for element in inputs_metadata.get_elements(list(at) + [metadata_base.ALL_ELEMENTS]):
            semantic_types = inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS, element]).get('semantic_types', ())

            # TODO: Should we handle inheritance between semantic types here?
            if self.hyperparams['match_logic'] == 'all':
                matched = all(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types'])
            elif self.hyperparams['match_logic'] == 'any':
                matched = any(semantic_type in semantic_types for semantic_type in self.hyperparams['semantic_types'])
            else:
                raise exceptions.UnexpectedValueError("Unknown value of hyper-parameter \"match_logic\": {value}".format(value=self.hyperparams['match_logic']))

            if matched:
                # A match on ALL_ELEMENTS means every column matches: return
                # all column indices based on the dimension's recorded length.
                if element is metadata_base.ALL_ELEMENTS:
                    return list(range(inputs_metadata.query(list(at) + [metadata_base.ALL_ELEMENTS]).get('dimension', {}).get('length', 0)))
                else:
                    columns.append(typing.cast(int, element))

        return columns

    def _update_metadata(
        self, inputs_metadata: metadata_base.DataMetadata, resource_id: metadata_base.SelectorSegment,
        column_index: int, at: metadata_base.Selector,
    ) -> metadata_base.DataMetadata:
        """
        Return metadata with all "add_semantic_types" added to the given
        (resource, column) under selector ``at``. Input metadata is not mutated.
        """
        outputs_metadata = inputs_metadata

        for semantic_type in self.hyperparams['add_semantic_types']:
            outputs_metadata = outputs_metadata.add_semantic_type(tuple(at) + (resource_id, metadata_base.ALL_ELEMENTS, column_index), semantic_type)

        return outputs_metadata
| @@ -0,0 +1,88 @@ | |||
| import os | |||
| import typing | |||
| import numpy | |||
| import pandas | |||
| from sklearn import model_selection | |||
| from d3m import container, exceptions, utils as d3m_utils | |||
| from d3m.metadata import base as metadata_base, hyperparams | |||
| from d3m.base import primitives | |||
# Public API of this module.
__all__ = ('TrainScoreDatasetSplitPrimitive',)
class Hyperparams(hyperparams.Hyperparams):
    """Hyper-parameters for the train/score dataset split."""

    # Fraction of rows placed in the train split; the remainder goes to score.
    train_score_ratio = hyperparams.Uniform(
        lower=0,
        upper=1,
        default=0.75,
        upper_inclusive=True,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="The ratio between the train and score data and represents the proportion of the Dataset to include in the train split. The rest is included in the score split.",
    )
    # Preserve per-class proportions in each split (requires target columns).
    stratified = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Do stratified folds. The folds are made by preserving the percentage of samples for each class.",
    )
    shuffle = hyperparams.UniformBool(
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Whether to shuffle the data before splitting into batches.",
    )
    # When rows are removed from the entry-point resource, also drop dependent
    # rows in other resources.
    delete_recursive = hyperparams.Hyperparameter[bool](
        default=False,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Delete rows in other resources/tables which are not needed for rows left in the dataset entry point resource/table.",
    )
class TrainScoreDatasetSplitPrimitive(primitives.TabularSplitPrimitiveBase[Hyperparams]):
    """
    A primitive which splits a tabular Dataset into random train and score subsets.
    """

    metadata = metadata_base.PrimitiveMetadata(
        {
            'id': '3fcc6dc4-6681-4c86-948e-066d14e7d803',
            'version': '0.1.0',
            'name': "Train-score tabular dataset splits",
            'python_path': 'd3m.primitives.tods.evaluation.train_score_dataset_split',
            'source': {
                'name': 'DATALab@Texas A&M University',
                'contact': 'mailto:mitar.commonprimitives@tnode.com',
                'uris': [
                    'https://gitlab.com/datadrivendiscovery/common-primitives/blob/master/common_primitives/train_score_split.py',
                    'https://gitlab.com/datadrivendiscovery/common-primitives.git',
                ],
            },
            'installation': [{
                'type': metadata_base.PrimitiveInstallationType.PIP,
                'package_uri': 'git+https://gitlab.com/datadrivendiscovery/common-primitives.git@{git_commit}#egg=common_primitives'.format(
                    git_commit=d3m_utils.current_git_commit(os.path.dirname(__file__)),
                ),
            }],
            'algorithm_types': [
                metadata_base.PrimitiveAlgorithmType.HOLDOUT,
                metadata_base.PrimitiveAlgorithmType.DATA_SPLITTING,
            ],
            'primitive_family': metadata_base.PrimitiveFamily.EVALUATION,
        },
    )

    def _get_splits(self, attributes: pandas.DataFrame, targets: pandas.DataFrame, dataset: container.Dataset, main_resource_id: str) -> typing.List[typing.Tuple[numpy.ndarray, numpy.ndarray]]:
        """
        Return a single (train, score) pair of row-index arrays produced by
        scikit-learn's ``train_test_split`` over the rows of ``attributes``.

        Raises ``InvalidArgumentValueError`` when stratification is requested
        but no target columns are available.
        """
        # Resolve the stratification target up-front (guard clause).
        stratify = None
        if self.hyperparams['stratified']:
            if not len(targets.columns):
                raise exceptions.InvalidArgumentValueError("Stratified split is requested, but no target columns found.")
            stratify = targets

        row_indices = numpy.arange(len(attributes))
        train_indices, score_indices = model_selection.train_test_split(
            row_indices,
            # test_size=None makes sklearn derive it as the complement of train_size.
            test_size=None,
            train_size=self.hyperparams['train_score_ratio'],
            random_state=self._random_state,
            shuffle=self.hyperparams['shuffle'],
            stratify=stratify,
        )

        return [(train_indices, score_indices)]
| @@ -0,0 +1,192 @@ | |||
| import datetime | |||
| import logging | |||
| import typing | |||
| import dateutil.parser | |||
| import numpy | |||
| from d3m import container, deprecate | |||
| from d3m.base import utils as base_utils | |||
| from d3m.metadata import base as metadata_base | |||
logger = logging.getLogger(__name__)

# Default used by dateutil to fill in missing datetime components; pinned to
# the UNIX epoch in UTC so partially-specified values parse deterministically.
DEFAULT_DATETIME = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
@deprecate.function(message="it should not be used anymore")
def copy_elements_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector,
                           to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper around the private ``Metadata._copy_elements_metadata``."""
    # NOTE: "check" and "source" are accepted for backward compatibility but not forwarded.
    return source_metadata._copy_elements_metadata(target_metadata, list(from_selector), list(to_selector), [], ignore_all_elements)
@deprecate.function(message="use Metadata.copy_to method instead")
def copy_metadata(source_metadata: metadata_base.Metadata, target_metadata: metadata_base.DataMetadata, from_selector: metadata_base.Selector,
                  to_selector: metadata_base.Selector = (), *, ignore_all_elements: bool = False, check: bool = True, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``Metadata.copy_to``."""
    # NOTE: "check" and "source" are accepted for backward compatibility but not forwarded.
    return source_metadata.copy_to(target_metadata, from_selector, to_selector, ignore_all_elements=ignore_all_elements)
@deprecate.function(message="use DataFrame.select_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def select_columns(inputs: container.DataFrame, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *,
                   source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``DataFrame.select_columns``."""
    return inputs.select_columns(columns)
@deprecate.function(message="use DataMetadata.select_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def select_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns: typing.Sequence[metadata_base.SimpleSelectorSegment], *,
                            source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.select_columns``."""
    return inputs_metadata.select_columns(columns)
@deprecate.function(message="use DataMetadata.list_columns_with_semantic_types method instead")
def list_columns_with_semantic_types(metadata: metadata_base.DataMetadata, semantic_types: typing.Sequence[str], *,
                                     at: metadata_base.Selector = ()) -> typing.Sequence[int]:
    """Deprecated wrapper delegating to ``DataMetadata.list_columns_with_semantic_types``."""
    return metadata.list_columns_with_semantic_types(semantic_types, at=at)
@deprecate.function(message="use DataMetadata.list_columns_with_structural_types method instead")
def list_columns_with_structural_types(metadata: metadata_base.DataMetadata, structural_types: typing.Union[typing.Callable, typing.Sequence[typing.Union[str, type]]], *,
                                       at: metadata_base.Selector = ()) -> typing.Sequence[int]:
    """Deprecated wrapper delegating to ``DataMetadata.list_columns_with_structural_types``."""
    return metadata.list_columns_with_structural_types(structural_types, at=at)
@deprecate.function(message="use DataFrame.remove_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def remove_columns(inputs: container.DataFrame, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``DataFrame.remove_columns``."""
    return inputs.remove_columns(column_indices)
@deprecate.function(message="use DataMetadata.remove_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def remove_columns_metadata(inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.remove_columns``."""
    return inputs_metadata.remove_columns(column_indices)
@deprecate.function(message="use DataFrame.append_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def append_columns(left: container.DataFrame, right: container.DataFrame, *, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``DataFrame.append_columns``."""
    return left.append_columns(right, use_right_metadata=use_right_metadata)
@deprecate.function(message="use DataMetadata.append_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def append_columns_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.append_columns``."""
    # NOTE(review): unlike the sibling wrappers, "use_right_metadata" and "source"
    # here are positional-or-keyword (no "*" marker) — presumably an oversight,
    # but making them keyword-only now would break positional callers.
    return left_metadata.append_columns(right_metadata, use_right_metadata=use_right_metadata)
@deprecate.function(message="use DataFrame.insert_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def insert_columns(inputs: container.DataFrame, columns: container.DataFrame, at_column_index: int, *, source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``DataFrame.insert_columns``."""
    return inputs.insert_columns(columns, at_column_index)
@deprecate.function(message="use DataMetadata.insert_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def insert_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, at_column_index: int, *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.insert_columns``."""
    return inputs_metadata.insert_columns(columns_metadata, at_column_index)
@deprecate.function(message="use DataFrame.replace_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def replace_columns(inputs: container.DataFrame, columns: container.DataFrame, column_indices: typing.Sequence[int], *, copy: bool = True, source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``DataFrame.replace_columns``."""
    return inputs.replace_columns(columns, column_indices, copy=copy)
@deprecate.function(message="use DataMetadata.replace_columns method instead")
@deprecate.arguments('source', message="argument ignored")
def replace_columns_metadata(inputs_metadata: metadata_base.DataMetadata, columns_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int], *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.replace_columns``."""
    return inputs_metadata.replace_columns(columns_metadata, column_indices)
@deprecate.function(message="use DataMetadata.get_index_columns method instead")
def get_index_columns(metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = ()) -> typing.Sequence[int]:
    """Deprecated wrapper delegating to ``DataMetadata.get_index_columns``."""
    return metadata.get_index_columns(at=at)
@deprecate.function(message="use DataFrame.horizontal_concat method instead")
@deprecate.arguments('source', message="argument ignored")
def horizontal_concat(left: container.DataFrame, right: container.DataFrame, *, use_index: bool = True,
                      remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``DataFrame.horizontal_concat``."""
    return left.horizontal_concat(right, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata)
@deprecate.function(message="use DataMetadata.horizontal_concat method instead")
@deprecate.arguments('source', message="argument ignored")
def horizontal_concat_metadata(left_metadata: metadata_base.DataMetadata, right_metadata: metadata_base.DataMetadata, *, use_index: bool = True,
                               remove_second_index: bool = True, use_right_metadata: bool = False, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.horizontal_concat``."""
    return left_metadata.horizontal_concat(right_metadata, use_index=use_index, remove_second_index=remove_second_index, use_right_metadata=use_right_metadata)
@deprecate.function(message="use d3m.base.utils.get_columns_to_use function instead")
def get_columns_to_use(metadata: metadata_base.DataMetadata, use_columns: typing.Sequence[int], exclude_columns: typing.Sequence[int],
                       can_use_column: typing.Callable) -> typing.Tuple[typing.List[int], typing.List[int]]:
    """Deprecated wrapper delegating to ``d3m.base.utils.get_columns_to_use``."""
    return base_utils.get_columns_to_use(metadata, use_columns, exclude_columns, can_use_column)
@deprecate.function(message="use d3m.base.utils.combine_columns function instead")
@deprecate.arguments('source', message="argument ignored")
def combine_columns(return_result: str, add_index_columns: bool, inputs: container.DataFrame, column_indices: typing.Sequence[int],
                    columns_list: typing.Sequence[container.DataFrame], *, source: typing.Any = None) -> container.DataFrame:
    """Deprecated wrapper delegating to ``d3m.base.utils.combine_columns``.

    Note the argument order differs from the replacement function.
    """
    return base_utils.combine_columns(inputs, column_indices, columns_list, return_result=return_result, add_index_columns=add_index_columns)
@deprecate.function(message="use d3m.base.utils.combine_columns_metadata function instead")
@deprecate.arguments('source', message="argument ignored")
def combine_columns_metadata(return_result: str, add_index_columns: bool, inputs_metadata: metadata_base.DataMetadata, column_indices: typing.Sequence[int],
                             columns_metadata_list: typing.Sequence[metadata_base.DataMetadata], *, source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``d3m.base.utils.combine_columns_metadata``.

    Note the argument order differs from the replacement function.
    """
    return base_utils.combine_columns_metadata(inputs_metadata, column_indices, columns_metadata_list, return_result=return_result, add_index_columns=add_index_columns)
@deprecate.function(message="use DataMetadata.set_table_metadata method instead")
@deprecate.arguments('source', message="argument ignored")
def set_table_metadata(inputs_metadata: metadata_base.DataMetadata, *, at: metadata_base.Selector = (), source: typing.Any = None) -> metadata_base.DataMetadata:
    """Deprecated wrapper delegating to ``DataMetadata.set_table_metadata``."""
    return inputs_metadata.set_table_metadata(at=at)
@deprecate.function(message="use DataMetadata.get_column_index_from_column_name method instead")
def get_column_index_from_column_name(inputs_metadata: metadata_base.DataMetadata, column_name: str, *, at: metadata_base.Selector = ()) -> int:
    """Deprecated wrapper delegating to ``DataMetadata.get_column_index_from_column_name``."""
    return inputs_metadata.get_column_index_from_column_name(column_name, at=at)
@deprecate.function(message="use Dataset.get_relations_graph method instead")
def build_relation_graph(dataset: container.Dataset) -> typing.Dict[str, typing.List[typing.Tuple[str, bool, int, int, typing.Dict]]]:
    """Deprecated wrapper delegating to ``Dataset.get_relations_graph``."""
    return dataset.get_relations_graph()
@deprecate.function(message="use d3m.base.utils.get_tabular_resource function instead")
def get_tabular_resource(dataset: container.Dataset, resource_id: typing.Optional[str], *,
                         pick_entry_point: bool = True, pick_one: bool = True, has_hyperparameter: bool = True) -> typing.Tuple[str, container.DataFrame]:
    """Deprecated wrapper delegating to ``d3m.base.utils.get_tabular_resource``."""
    return base_utils.get_tabular_resource(dataset, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one, has_hyperparameter=has_hyperparameter)
@deprecate.function(message="use d3m.base.utils.get_tabular_resource_metadata function instead")
def get_tabular_resource_metadata(dataset_metadata: metadata_base.DataMetadata, resource_id: typing.Optional[metadata_base.SelectorSegment], *,
                                  pick_entry_point: bool = True, pick_one: bool = True) -> metadata_base.SelectorSegment:
    """Deprecated wrapper delegating to ``d3m.base.utils.get_tabular_resource_metadata``."""
    return base_utils.get_tabular_resource_metadata(dataset_metadata, resource_id, pick_entry_point=pick_entry_point, pick_one=pick_one)
@deprecate.function(message="use Dataset.select_rows method instead")
@deprecate.arguments('source', message="argument ignored")
def cut_dataset(dataset: container.Dataset, row_indices_to_keep: typing.Mapping[str, typing.Sequence[int]], *,
                source: typing.Any = None) -> container.Dataset:
    """Deprecated wrapper delegating to ``Dataset.select_rows``."""
    return dataset.select_rows(row_indices_to_keep)
def parse_datetime(value: str, *, fuzzy: bool = True) -> typing.Optional[datetime.datetime]:
    """
    Parse ``value`` into a ``datetime``, returning ``None`` when parsing fails.

    Missing components are filled from ``DEFAULT_DATETIME``; ``fuzzy`` is
    forwarded to ``dateutil.parser.parse``.
    """
    try:
        parsed = dateutil.parser.parse(value, default=DEFAULT_DATETIME, fuzzy=fuzzy)
    except (ValueError, OverflowError, TypeError):
        # Unparseable, out-of-range, or non-string input.
        return None
    return parsed
def parse_datetime_to_float(value: str, *, fuzzy: bool = True) -> float:
    """
    Parse ``value`` into a POSIX timestamp (float), returning NaN on failure.
    """
    # parse_datetime never raises; it returns None on failure.
    parsed = parse_datetime(value, fuzzy=fuzzy)
    if parsed is None:
        return numpy.nan

    try:
        return parsed.timestamp()
    except (ValueError, OverflowError, TypeError):
        # Timestamp conversion itself can fail for extreme dates.
        return numpy.nan