mirror of
https://github.com/apirrone/Open_Duck_Mini_Runtime.git
synced 2025-09-01 10:44:00 +00:00
rma
This commit is contained in:
parent
ce88afdeb6
commit
bfa0d0264f
2 changed files with 28 additions and 3 deletions
|
@ -12,7 +12,7 @@ class OnnxInfer:
|
|||
|
||||
def infer(self, inputs):
|
||||
if self.awd:
|
||||
outputs = self.ort_session.run(None, {"obs": [inputs]})
|
||||
outputs = self.ort_session.run(None, {self.input_name: [inputs]})
|
||||
return outputs[0][0]
|
||||
else:
|
||||
outputs = self.ort_session.run(
|
||||
|
|
|
@ -65,15 +65,29 @@ class RLWalk:
|
|||
replay_actions=None,
|
||||
zero_head=False,
|
||||
stand=False,
|
||||
rma=False,
|
||||
adaptation_module_path=None,
|
||||
):
|
||||
self.commands = commands
|
||||
self.pitch_bias = pitch_bias
|
||||
self.zero_head = zero_head
|
||||
self.stand = stand
|
||||
self.rma = rma
|
||||
self.adaptation_module_path = adaptation_module_path
|
||||
|
||||
self.num_obs = 56
|
||||
|
||||
self.onnx_model_path = onnx_model_path
|
||||
self.policy = OnnxInfer(self.onnx_model_path, awd=True)
|
||||
|
||||
if self.rma:
|
||||
self.adaptation_module = OnnxInfer(
|
||||
self.adaptation_module_path, "rma_history", awd=True
|
||||
)
|
||||
self.rma_obs_history_size = 50
|
||||
self.rma_obs_history = np.zeros((self.rma_obs_history_size, self.num_obs)).tolist()
|
||||
self.rma_decimation = 5 # 10 Hz if control_freq = 50 Hz
|
||||
|
||||
self.replay_obs = replay_obs
|
||||
if self.replay_obs is not None:
|
||||
self.replay_obs = pickle.load(open(self.replay_obs, "rb"))
|
||||
|
@ -83,8 +97,6 @@ class RLWalk:
|
|||
self.replay_actions = pickle.load(open(self.replay_actions, "rb"))
|
||||
self.replay_obs = None
|
||||
|
||||
self.num_obs = 56
|
||||
|
||||
# Control
|
||||
self.control_freq = control_freq
|
||||
self.pid = pid
|
||||
|
@ -267,6 +279,7 @@ class RLWalk:
|
|||
# voltages = []
|
||||
i = 0
|
||||
start = time.time()
|
||||
latent = None
|
||||
try:
|
||||
print("Starting")
|
||||
while True:
|
||||
|
@ -282,6 +295,14 @@ class RLWalk:
|
|||
else:
|
||||
break
|
||||
|
||||
if self.rma:
|
||||
self.rma_obs_history.append(obs)
|
||||
self.rma_obs_history = self.rma_obs_history[-self.rma_obs_history_size:]
|
||||
|
||||
if i % self.rma_decimation == 0 or latent is None:
|
||||
latent = self.adaptation_module.infer(np.array(self.rma_obs_history).flatten())
|
||||
obs = np.concatenate([obs, latent])
|
||||
|
||||
obs = np.clip(obs, -100, 100)
|
||||
|
||||
if self.replay_actions is None:
|
||||
|
@ -370,6 +391,8 @@ if __name__ == "__main__":
|
|||
)
|
||||
parser.add_argument("--replay_obs", type=str, required=False, default=None)
|
||||
parser.add_argument("--replay_actions", type=str, required=False, default=None)
|
||||
parser.add_argument("--rma", action="store_true", default=False)
|
||||
parser.add_argument("--adaptation_module_path", type=str, required=False)
|
||||
args = parser.parse_args()
|
||||
pid = [args.p, args.i, args.d]
|
||||
|
||||
|
@ -386,6 +409,8 @@ if __name__ == "__main__":
|
|||
replay_actions=args.replay_actions,
|
||||
zero_head=args.zero_head,
|
||||
stand=args.stand,
|
||||
rma=args.rma,
|
||||
adaptation_module_path=args.adaptation_module_path,
|
||||
)
|
||||
print("Done instantiating RLWalk")
|
||||
# rl_walk.start()
|
||||
|
|
Loading…
Reference in a new issue