import os

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader

from constants import IMG_H, IMG_W


class EpisodicDataset(torch.utils.data.Dataset):
    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_dir = dataset_dir
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.is_sim = None
        self.__getitem__(0)  # read one item up front to initialize self.is_sim

    def __len__(self):
        return len(self.episode_ids)

    def __getitem__(self, index):
        sample_full_episode = False  # hardcoded: sample a random start timestep

        episode_id = self.episode_ids[index]
        dataset_path = os.path.join(self.dataset_dir, f"episode_{episode_id}.hdf5")
        with h5py.File(dataset_path, "r") as root:
            is_sim = False  # hardcoded
            original_action_shape = root["/action"].shape
            episode_len = original_action_shape[0]
            if sample_full_episode:
                start_ts = 0
            else:
                start_ts = np.random.choice(episode_len)
            # get observation at start_ts only
            qpos = root["/observations/qpos"][start_ts]
            image_dict = dict()
            for cam_name in self.camera_names:
                raw_img = root[f"/observations/images/{cam_name}"][start_ts]
                resized_img = cv2.resize(
                    np.array(raw_img), (IMG_W, IMG_H), interpolation=cv2.INTER_LINEAR
                )
                image_dict[cam_name] = resized_img
            # get all actions after and including start_ts
            action = root["/action"][start_ts:]
            action_len = episode_len - start_ts

        self.is_sim = is_sim
        # pad the action chunk to the full episode length and mark the padded steps
        padded_action = np.zeros(original_action_shape, dtype=np.float32)
        padded_action[:action_len] = action
        is_pad = np.zeros(episode_len)
        is_pad[action_len:] = 1

        # new leading axis for the different cameras
        all_cam_images = []
        for cam_name in self.camera_names:
            all_cam_images.append(image_dict[cam_name])
        all_cam_images = np.stack(all_cam_images, axis=0)

        # construct observations
        image_data = torch.from_numpy(all_cam_images)
        qpos_data = torch.tensor(qpos, dtype=torch.float32)
        action_data = torch.tensor(padded_action, dtype=torch.float32)
        is_pad = torch.from_numpy(is_pad).bool()

        # channel last -> channel first
        image_data = torch.einsum("k h w c -> k c h w", image_data)

        # scale images to [0, 1]; standardize actions and qpos with dataset stats
        image_data = image_data / 255.0
        action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]

        return image_data, qpos_data, action_data, is_pad
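

# Hedged sketch, not part of the original module: the inverse of the action
# normalization applied in EpisodicDataset.__getitem__, as one would apply it
# to model predictions at inference time. The name `unnormalize_action` is an
# assumption for illustration.
def unnormalize_action(action_pred, norm_stats):
    # undo the (a - mean) / std standardization from __getitem__
    return action_pred * norm_stats["action_std"] + norm_stats["action_mean"]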


def get_norm_stats(dataset_dir, num_episodes):
    all_qpos_data = []
    all_action_data = []
    for episode_idx in range(num_episodes):
        dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}.hdf5")
        with h5py.File(dataset_path, "r") as root:
            qpos = root["/observations/qpos"][()]
            action = root["/action"][()]
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
    # stacking assumes every episode has the same length
    all_qpos_data = torch.stack(all_qpos_data)
    all_action_data = torch.stack(all_action_data)

    # normalize action data
    action_mean = all_action_data.mean(dim=[0, 1], keepdim=True)
    action_std = all_action_data.std(dim=[0, 1], keepdim=True)
    action_std = torch.clip(action_std, 1e-2, np.inf)  # avoid division by near-zero std

    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True)
    qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True)
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # avoid division by near-zero std

    stats = {
        "action_mean": action_mean.numpy().squeeze(),
        "action_std": action_std.numpy().squeeze(),
        "qpos_mean": qpos_mean.numpy().squeeze(),
        "qpos_std": qpos_std.numpy().squeeze(),
        "example_qpos": qpos,  # qpos trajectory of the last episode, kept for reference
    }

    return stats
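

# Hedged sketch, an assumption rather than part of this module: persisting the
# returned stats (plain numpy arrays) so that inference can apply the exact
# same normalization. The file name "dataset_stats.pkl" is illustrative only.
def save_norm_stats(stats, ckpt_dir):
    import pickle

    with open(os.path.join(ckpt_dir, "dataset_stats.pkl"), "wb") as f:
        pickle.dump(stats, f)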


def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
    print(f"\nData from: {dataset_dir}\n")
    # obtain train/val split
    train_ratio = 0.8
    shuffled_indices = np.random.permutation(num_episodes)
    train_indices = shuffled_indices[: int(train_ratio * num_episodes)]
    val_indices = shuffled_indices[int(train_ratio * num_episodes) :]

    # obtain normalization stats for qpos and action, computed over all episodes
    norm_stats = get_norm_stats(dataset_dir, num_episodes)

    # construct dataset and dataloader
    train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats)
    val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        pin_memory=True,
        num_workers=1,
        prefetch_factor=1,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size_val,
        shuffle=True,
        pin_memory=True,
        num_workers=1,
        prefetch_factor=1,
    )

    return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim


def compute_dict_mean(epoch_dicts):
    # average each key over a list of dicts, e.g. per-batch loss dicts
    result = {k: None for k in epoch_dicts[0]}
    num_items = len(epoch_dicts)
    for k in result:
        value_sum = 0
        for epoch_dict in epoch_dicts:
            value_sum += epoch_dict[k]
        result[k] = value_sum / num_items
    return result


def detach_dict(d):
    # detach every tensor value from the autograd graph
    new_d = dict()
    for k, v in d.items():
        new_d[k] = v.detach()
    return new_d


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
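

# Usage sketch (hedged): minimal wiring of the pieces above. The dataset
# directory, episode count, and camera name below are assumptions for
# illustration, not values defined by this module.
if __name__ == "__main__":
    set_seed(0)
    train_loader, val_loader, norm_stats, is_sim = load_data(
        dataset_dir="data/sim_transfer_cube",  # hypothetical path
        num_episodes=50,                       # hypothetical count
        camera_names=["top"],                  # hypothetical camera name
        batch_size_train=8,
        batch_size_val=8,
    )
    image_data, qpos_data, action_data, is_pad = next(iter(train_loader))
    # shapes: image_data (B, num_cams, C, H, W); qpos_data (B, dof);
    # action_data (B, episode_len, dof); is_pad (B, episode_len)
    print(image_data.shape, qpos_data.shape, action_data.shape, is_pad.shape)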