Warn if using offline data buffer without Agent wrapper #1771
QuantuMope wants to merge 3 commits into pytorch from
Conversation
alf/algorithms/sac_algorithm.py (Outdated)

    action, action_distribution):
        if isinstance(rollout_info, BasicRolloutInfo):
            rollout_info = rollout_info.rl
This should be put outside of this function. The general principle is that the algorithm should always receive what it's supposed to receive. In this case, this means that the rollout_info passed in should already be SacInfo.
+1
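A minimal sketch of this suggestion, assuming a hypothetical helper name and call site (BasicRolloutInfo comes from the ALF codebase; its import path is not shown in this thread): the unwrapping is done by the caller, so SacAlgorithm itself only ever receives SacInfo.

```python
# Hypothetical helper at the caller / Agent level, not inside SacAlgorithm.
def _unwrap_rollout_info(rollout_info):
    """Return the algorithm-level info regardless of how it is wrapped."""
    if isinstance(rollout_info, BasicRolloutInfo):  # class from the ALF codebase
        return rollout_info.rl
    return rollout_info

# Caller side: SAC's _critic_train_step() then only ever sees SacInfo.
# sac_info = _unwrap_rollout_info(rollout_info)
```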
alf/algorithms/sac_algorithm.py (Outdated)

    if isinstance(rollout_info, BasicRolloutInfo):
        rollout_info = rollout_info.rl
    state: SacCriticState,
    rollout_info: SacInfo | BasicRLInfo, action,
Still should always be SacInfo? If it's BasicRLInfo, the algorithm will crash.
Offline buffer data is stored as BasicRLInfo, which consists of just (s, a, r) data.
I guess I could convert BasicRLInfo into SacInfo with some fields empty? Not sure which is the better design. Lmk which one you think is cleaner and I can change it.
> Offline buffer data is stored as BasicRLInfo, which consists of just (s, a, r) data.

If you look at SAC's train_step(), it will get access to rollout_info.repr. This means that SAC is currently incompatible with offline training.
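A rough sketch of the nesting being discussed, with field names inferred from this thread rather than taken from the ALF source (all names and fields here are assumptions):

```python
import collections

# Offline replay buffer entries: essentially just (s, a, r), as noted above.
BasicRLInfo = collections.namedtuple(
    'BasicRLInfo', ['observation', 'action', 'reward'])

# Agent-level wrapper: the RL algorithm's own info under `rl`, plus the
# representation learner's info under `repr`, which SAC's train_step()
# reads as rollout_info.repr.
BasicRolloutInfo = collections.namedtuple(
    'BasicRolloutInfo', ['rl', 'repr'])
```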
I'm using a frozen encoder, so I'm not training a repr.
Wait, it seems that repr is stored in BasicRolloutInfo. Not sure how this code was running then. I'll take a look.
It's using elastic_namedtuple so any missing field returns (). Anyway, a little weird but it works.
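For readers unfamiliar with that behavior, here is a tiny illustration using a plain namedtuple with () defaults; the actual elastic_namedtuple in the codebase may be implemented differently, and the SacInfo fields below are placeholders:

```python
import collections

SacInfo = collections.namedtuple(
    'SacInfo', ['action_distribution', 'actor', 'critic', 'repr'])
# Give every field a default of (), mimicking the "missing field returns ()"
# behavior described above.
SacInfo.__new__.__defaults__ = ((),) * len(SacInfo._fields)

info = SacInfo(actor='actor_info')  # fields not supplied default to ()
assert info.repr == ()              # so reading .repr does not crash
```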
Ideally, we should not include BasicRLInfo here, as it could confuse pure SAC users. The better alternative might be to comply with the Agent assumption and possibly extend it.
Removed type hints. Also added a warning message advising users to use Agent. Before, the code would simply crash due to an interface conflict.
            logging.WARNING,
            "Detected offline buffer training without Agent wrapper. "
            "For best compatibility, it is advised to use the Agent wrapper.",
            n=1)
This warning won't work. When using Agent, we still get rollout_info as BasicRolloutInfo.
It will. When using Agent, it properly feeds the nested BasicRLInfo to this function instead.
In other words, there was never a bug with hybrid RL training; it was just never meant to be used without Agent.
> It will. When using Agent, it properly feeds the nested BasicRLInfo to this function instead.

You're right. I didn't know Agent overrides this function itself.
This PR now warns users if they are using the offline data buffer without the Agent wrapper.
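For reference, a sketch of how the added warning could sit in the code, based only on the diff fragment quoted above. The enclosing function, the unwrapping condition, and the origin of BasicRolloutInfo are assumptions; absl's log_every_n is used here because it matches the visible logging.WARNING and n=1 arguments, though the actual call name is cut off in the diff.

```python
from absl import logging
# BasicRolloutInfo comes from the ALF codebase; its import path is not
# shown in the PR thread, so it is omitted here.


def _maybe_unwrap_offline_info(rollout_info):
    """Warn and unwrap when offline buffer data arrives without the Agent wrapper."""
    if isinstance(rollout_info, BasicRolloutInfo):
        logging.log_every_n(
            logging.WARNING,
            "Detected offline buffer training without Agent wrapper. "
            "For best compatibility, it is advised to use the Agent wrapper.",
            n=1)
        rollout_info = rollout_info.rl
    return rollout_info
```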