This commit is contained in:
apirrone 2025-02-01 23:09:49 +01:00
parent ce88afdeb6
commit bfa0d0264f
2 changed files with 28 additions and 3 deletions

View file

@@ -12,7 +12,7 @@ class OnnxInfer:
def infer(self, inputs):
if self.awd:
outputs = self.ort_session.run(None, {"obs": [inputs]})
outputs = self.ort_session.run(None, {self.input_name: [inputs]})
return outputs[0][0]
else:
outputs = self.ort_session.run(

View file

@@ -65,15 +65,29 @@ class RLWalk:
replay_actions=None,
zero_head=False,
stand=False,
rma=False,
adaptation_module_path=None,
):
self.commands = commands
self.pitch_bias = pitch_bias
self.zero_head = zero_head
self.stand = stand
self.rma = rma
self.adaptation_module_path = adaptation_module_path
self.num_obs = 56
self.onnx_model_path = onnx_model_path
self.policy = OnnxInfer(self.onnx_model_path, awd=True)
if self.rma:
self.adaptation_module = OnnxInfer(
self.adaptation_module_path, "rma_history", awd=True
)
self.rma_obs_history_size = 50
self.rma_obs_history = np.zeros((self.rma_obs_history_size, self.num_obs)).tolist()
self.rma_decimation = 5 # 10 Hz if control_freq = 50 Hz
self.replay_obs = replay_obs
if self.replay_obs is not None:
self.replay_obs = pickle.load(open(self.replay_obs, "rb"))
@@ -83,8 +97,6 @@ class RLWalk:
self.replay_actions = pickle.load(open(self.replay_actions, "rb"))
self.replay_obs = None
self.num_obs = 56
# Control
self.control_freq = control_freq
self.pid = pid
@@ -267,6 +279,7 @@ class RLWalk:
# voltages = []
i = 0
start = time.time()
latent = None
try:
print("Starting")
while True:
@@ -282,6 +295,14 @@ class RLWalk:
else:
break
if self.rma:
self.rma_obs_history.append(obs)
self.rma_obs_history = self.rma_obs_history[-self.rma_obs_history_size:]
if i % self.rma_decimation == 0 or latent is None:
latent = self.adaptation_module.infer(np.array(self.rma_obs_history).flatten())
obs = np.concatenate([obs, latent])
obs = np.clip(obs, -100, 100)
if self.replay_actions is None:
@@ -370,6 +391,8 @@ if __name__ == "__main__":
)
parser.add_argument("--replay_obs", type=str, required=False, default=None)
parser.add_argument("--replay_actions", type=str, required=False, default=None)
parser.add_argument("--rma", action="store_true", default=False)
parser.add_argument("--adaptation_module_path", type=str, required=False)
args = parser.parse_args()
pid = [args.p, args.i, args.d]
@@ -386,6 +409,8 @@ if __name__ == "__main__":
replay_actions=args.replay_actions,
zero_head=args.zero_head,
stand=args.stand,
rma=args.rma,
adaptation_module_path=args.adaptation_module_path,
)
print("Done instantiating RLWalk")
# rl_walk.start()