mask head

This commit is contained in:
apirrone 2025-08-08 18:33:15 +02:00
parent 6c514ed945
commit 9ff8fd3834

View file

@ -6,6 +6,7 @@ from mini_bdx_runtime.rustypot_position_hwi import HWI
from mini_bdx_runtime.onnx_infer import OnnxInfer
from mini_bdx_runtime.raw_imu import Imu
# from mini_bdx_runtime.imu import Imu
from mini_bdx_runtime.poly_reference_motion import PolyReferenceMotion
from mini_bdx_runtime.xbox_controller import XBoxController
@ -20,6 +21,7 @@ from mini_bdx_runtime.duck_config import DuckConfig
import os
HOME_DIR = os.path.expanduser("~")
MASK_HEAD = True
class RLWalk:
@ -37,7 +39,6 @@ class RLWalk:
replay_obs=None,
cutoff_frequency=None,
):
self.duck_config = DuckConfig(config_json_path=duck_config_path)
self.commands = commands
@ -123,10 +124,8 @@ class RLWalk:
self.antennas = Antennas()
def get_obs(self):
# imu_data = self.imu.get_data(as_mat=True)
# raw
# raw
imu_data = self.imu.get_data()
if imu_data is None:
print("IMU data is None, skipping observation retrieval")
@ -160,6 +159,16 @@ class RLWalk:
print(f"ERROR len(dof_vel) != {self.num_dofs}")
return None
_init_pos = self.init_pos
_motor_targets = self.motor_targets
if MASK_HEAD:
dof_pos = np.concatenate([dof_pos[:5], dof_pos[9:]])
dof_vel = np.concatenate([dof_vel[:5], dof_vel[9:]])
_init_pos = np.concatenate([self.init_pos[:5], self.init_pos[9:]])
_motor_targets = np.concatenate(
[self.motor_targets[:5], self.motor_targets[9:]]
)
cmds = self.last_commands
feet_contacts = self.feet_contacts.get()
@ -169,13 +178,13 @@ class RLWalk:
accelero,
gyro,
# gravity,
cmds,
dof_pos - self.init_pos,
cmds[:3],
dof_pos - _init_pos,
dof_vel * self.dof_vel_scale,
self.last_action,
self.last_last_action,
self.last_last_last_action,
self.motor_targets,
_motor_targets,
feet_contacts,
self.imitation_phase,
]
@ -197,7 +206,6 @@ class RLWalk:
time.sleep(2)
def get_phase_frequency_factor(self, x_velocity):
max_phase_frequency = 1.2
min_phase_frequency = 1.0
@ -322,8 +330,8 @@ class RLWalk:
self.prev_motor_targets = self.motor_targets.copy()
# head_motor_targets = self.last_commands[3:]# + self.motor_targets[5:9]
# self.motor_targets[5:9] = head_motor_targets
head_motor_targets = self.last_commands[3:]
self.motor_targets[5:9] = head_motor_targets
action_dict = make_action_dict(
self.motor_targets, list(self.hwi.joints.keys())