You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

record.py 9.9 kB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import h5py
  2. import os
  3. import datetime
  4. from dora import Node
  5. import numpy as np
  6. from convert import (
  7. convert_euler_to_rotation_matrix,
  8. compute_ortho6d_from_rotation_matrix,
  9. )
  10. STATE_VEC_IDX_MAPPING = {
  11. # [0, 10): right arm joint positions
  12. **{"arm_joint_{}_pos".format(i): i for i in range(10)},
  13. **{"right_arm_joint_{}_pos".format(i): i for i in range(10)},
  14. # [10, 15): right gripper joint positions
  15. **{"gripper_joint_{}_pos".format(i): i + 10 for i in range(5)},
  16. **{"right_gripper_joint_{}_pos".format(i): i + 10 for i in range(5)},
  17. "gripper_open": 10, # alias of right_gripper_joint_0_pos
  18. "right_gripper_open": 10,
  19. # [15, 25): right arm joint velocities
  20. **{"arm_joint_{}_vel".format(i): i + 15 for i in range(10)},
  21. **{"right_arm_joint_{}_vel".format(i): i + 15 for i in range(10)},
  22. # [25, 30): right gripper joint velocities
  23. **{"gripper_joint_{}_vel".format(i): i + 25 for i in range(5)},
  24. **{"right_gripper_joint_{}_vel".format(i): i + 25 for i in range(5)},
  25. "gripper_open_vel": 25, # alias of right_gripper_joint_0_vel
  26. "right_gripper_open_vel": 25,
  27. # [30, 33): right end effector positions
  28. "eef_pos_x": 30,
  29. "right_eef_pos_x": 30,
  30. "eef_pos_y": 31,
  31. "right_eef_pos_y": 31,
  32. "eef_pos_z": 32,
  33. "right_eef_pos_z": 32,
  34. # [33, 39): right end effector 6D pose
  35. "eef_angle_0": 33,
  36. "right_eef_angle_0": 33,
  37. "eef_angle_1": 34,
  38. "right_eef_angle_1": 34,
  39. "eef_angle_2": 35,
  40. "right_eef_angle_2": 35,
  41. "eef_angle_3": 36,
  42. "right_eef_angle_3": 36,
  43. "eef_angle_4": 37,
  44. "right_eef_angle_4": 37,
  45. "eef_angle_5": 38,
  46. "right_eef_angle_5": 38,
  47. # [39, 42): right end effector velocities
  48. "eef_vel_x": 39,
  49. "right_eef_vel_x": 39,
  50. "eef_vel_y": 40,
  51. "right_eef_vel_y": 40,
  52. "eef_vel_z": 41,
  53. "right_eef_vel_z": 41,
  54. # [42, 45): right end effector angular velocities
  55. "eef_angular_vel_roll": 42,
  56. "right_eef_angular_vel_roll": 42,
  57. "eef_angular_vel_pitch": 43,
  58. "right_eef_angular_vel_pitch": 43,
  59. "eef_angular_vel_yaw": 44,
  60. "right_eef_angular_vel_yaw": 44,
  61. # [45, 50): reserved
  62. # [50, 60): left arm joint positions
  63. **{"left_arm_joint_{}_pos".format(i): i + 50 for i in range(10)},
  64. # [60, 65): left gripper joint positions
  65. **{"left_gripper_joint_{}_pos".format(i): i + 60 for i in range(5)},
  66. "left_gripper_open": 60, # alias of left_gripper_joint_0_pos
  67. # [65, 75): left arm joint velocities
  68. **{"left_arm_joint_{}_vel".format(i): i + 65 for i in range(10)},
  69. # [75, 80): left gripper joint velocities
  70. **{"left_gripper_joint_{}_vel".format(i): i + 75 for i in range(5)},
  71. "left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel
  72. # [80, 83): left end effector positions
  73. "left_eef_pos_x": 80,
  74. "left_eef_pos_y": 81,
  75. "left_eef_pos_z": 82,
  76. # [83, 89): left end effector 6D pose
  77. "left_eef_angle_0": 83,
  78. "left_eef_angle_1": 84,
  79. "left_eef_angle_2": 85,
  80. "left_eef_angle_3": 86,
  81. "left_eef_angle_4": 87,
  82. "left_eef_angle_5": 88,
  83. # [89, 92): left end effector velocities
  84. "left_eef_vel_x": 89,
  85. "left_eef_vel_y": 90,
  86. "left_eef_vel_z": 91,
  87. # [92, 95): left end effector angular velocities
  88. "left_eef_angular_vel_roll": 92,
  89. "left_eef_angular_vel_pitch": 93,
  90. "left_eef_angular_vel_yaw": 94,
  91. # [95, 100): reserved
  92. # [100, 102): base linear velocities
  93. "base_vel_x": 100,
  94. "base_vel_y": 101,
  95. # [102, 103): base angular velocities
  96. "base_angular_vel": 102,
  97. # [103, 128): reserved
  98. }
  99. STATE_VEC_LEN = 128
  100. now = datetime.datetime.now()
  101. DATA_DIR = "/home/agilex/Desktop/" + now.strftime("%Y.%m.%d.%H.%M")
  102. os.makedirs(DATA_DIR, exist_ok=True)
  103. ## Make data dir if it does not exist
  104. if not os.path.exists(DATA_DIR):
  105. os.makedirs(DATA_DIR)
  106. def save_data(data_dict, dataset_path, data_size):
  107. with h5py.File(dataset_path + ".hdf5", "w", rdcc_nbytes=1024**2 * 2) as root:
  108. root.attrs["sim"] = False
  109. root.attrs["compress"] = False
  110. obs = root.create_group("observations")
  111. variable_length = h5py.vlen_dtype(np.dtype("uint8"))
  112. image = obs.create_group("images")
  113. _ = image.create_dataset(
  114. "cam_high",
  115. (data_size,),
  116. dtype=variable_length,
  117. )
  118. _ = image.create_dataset(
  119. "cam_left_wrist",
  120. (data_size,),
  121. dtype=variable_length,
  122. )
  123. _ = image.create_dataset(
  124. "cam_right_wrist",
  125. (data_size,),
  126. dtype=variable_length,
  127. )
  128. _ = obs.create_dataset("qpos", (data_size, 128))
  129. _ = root.create_dataset("action", (data_size, 128))
  130. # data_dict write into h5py.File
  131. for name, array in data_dict.items():
  132. print(name)
  133. if "images" in name:
  134. image[name][...] = array
  135. else:
  136. root[name][...] = array
  137. data_dict = {
  138. "/observations/qpos": [],
  139. "/observations/images/cam_high": [],
  140. "/observations/images/cam_left_wrist": [],
  141. "/observations/images/cam_right_wrist": [],
  142. "/action": [],
  143. }
  144. node = Node()
  145. LEAD_CAMERA = "/observations/images/cam_high"
  146. tmp_dict = {}
  147. i = 0
  148. start = False
  149. for event in node:
  150. if event["type"] == "INPUT":
  151. if "save" in event["id"]:
  152. char = event["value"][0].as_py()
  153. if char == "p":
  154. if start == False:
  155. continue
  156. save_data(
  157. data_dict,
  158. f"{DATA_DIR}/episode_{i}",
  159. len(data_dict["/observations/qpos"]),
  160. )
  161. # Reset dict
  162. data_dict = {
  163. "/observations/qpos": [],
  164. "/observations/images/cam_high": [],
  165. "/observations/images/cam_left_wrist": [],
  166. "/observations/images/cam_right_wrist": [],
  167. "/action": [],
  168. }
  169. i += 1
  170. start = False
  171. elif char == "s":
  172. start = True
  173. elif "image" in event["id"]:
  174. tmp_dict[event["id"]] = event["value"].to_numpy()
  175. elif "qpos" in event["id"]:
  176. tmp_dict[event["id"]] = event["value"].to_numpy()
  177. elif "pose" in event["id"]:
  178. value = event["value"].to_numpy()
  179. euler = value[None, 3:6] # Add batch dimension
  180. rotmat = convert_euler_to_rotation_matrix(euler)
  181. ortho6d = compute_ortho6d_from_rotation_matrix(rotmat)[0]
  182. values = np.array(
  183. [
  184. value[0],
  185. value[1],
  186. value[2],
  187. ortho6d[0],
  188. ortho6d[1],
  189. ortho6d[2],
  190. ortho6d[3],
  191. ortho6d[4],
  192. ortho6d[5],
  193. ]
  194. )
  195. tmp_dict[event["id"]] = values
  196. elif "base_vel" in event["id"]:
  197. tmp_dict[event["id"]] = event["value"].to_numpy()
  198. # Check if tmp dict is full
  199. if len(tmp_dict) != 7:
  200. continue
  201. elif event["id"] == LEAD_CAMERA and start == True:
  202. values = np.concatenate(
  203. [
  204. tmp_dict["/observations/qpos_left"],
  205. tmp_dict["/observations/qpos_right"],
  206. tmp_dict["/observations/pose_left"],
  207. tmp_dict["/observations/pose_right"],
  208. # tmp_dict["/observations/base_vel"],
  209. ]
  210. )
  211. UNI_STATE_INDICES = (
  212. [STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(6)]
  213. + [STATE_VEC_IDX_MAPPING["left_gripper_open"]]
  214. + [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]
  215. + [STATE_VEC_IDX_MAPPING["right_gripper_open"]]
  216. + [STATE_VEC_IDX_MAPPING["left_eef_pos_x"]]
  217. + [STATE_VEC_IDX_MAPPING["left_eef_pos_y"]]
  218. + [STATE_VEC_IDX_MAPPING["left_eef_pos_z"]]
  219. + [STATE_VEC_IDX_MAPPING["left_eef_angle_0"]]
  220. + [STATE_VEC_IDX_MAPPING["left_eef_angle_1"]]
  221. + [STATE_VEC_IDX_MAPPING["left_eef_angle_2"]]
  222. + [STATE_VEC_IDX_MAPPING["left_eef_angle_3"]]
  223. + [STATE_VEC_IDX_MAPPING["left_eef_angle_4"]]
  224. + [STATE_VEC_IDX_MAPPING["left_eef_angle_5"]]
  225. + [STATE_VEC_IDX_MAPPING["right_eef_pos_x"]]
  226. + [STATE_VEC_IDX_MAPPING["right_eef_pos_y"]]
  227. + [STATE_VEC_IDX_MAPPING["right_eef_pos_z"]]
  228. + [STATE_VEC_IDX_MAPPING["right_eef_angle_0"]]
  229. + [STATE_VEC_IDX_MAPPING["right_eef_angle_1"]]
  230. + [STATE_VEC_IDX_MAPPING["right_eef_angle_2"]]
  231. + [STATE_VEC_IDX_MAPPING["right_eef_angle_3"]]
  232. + [STATE_VEC_IDX_MAPPING["right_eef_angle_4"]]
  233. + [STATE_VEC_IDX_MAPPING["right_eef_angle_5"]]
  234. # + [STATE_VEC_IDX_MAPPING["base_vel_x"]]
  235. # + [STATE_VEC_IDX_MAPPING["base_angular_vel"]],
  236. )
  237. universal_vec = np.zeros(STATE_VEC_LEN)
  238. universal_vec[UNI_STATE_INDICES] = values
  239. data_dict["/observations/qpos"].append(universal_vec)
  240. # We reproduce obs and action
  241. data_dict["/action"].append(universal_vec)
  242. data_dict["/observations/images/cam_high"].append(
  243. tmp_dict["/observations/images/cam_high"]
  244. )
  245. data_dict["/observations/images/cam_left_wrist"].append(
  246. tmp_dict["/observations/images/cam_left_wrist"]
  247. )
  248. data_dict["/observations/images/cam_right_wrist"].append(
  249. tmp_dict["/observations/images/cam_right_wrist"]
  250. )