Hi, I am trying Sampled EfficientZero for a large, discrete action space, and I ran into ~3 bugs in ptree_sez.py. I believe the following code has some issues:

    if self.legal_actions is not None:
        # first use the self.legal_actions to exclude the illegal actions
        policy_tmp = [0. for _ in range(self.action_space_size)]
        for index, legal_action in enumerate(self.legal_actions):
            policy_tmp[legal_action] = policy_logits[index]
        policy_logits = policy_tmp
    # then empty the self.legal_actions
    self.legal_actions = []
    prob = torch.softmax(torch.tensor(policy_logits), dim=-1)
    sampled_actions = torch.multinomial(prob, self.num_of_sampled_actions, replacement=False)
    for action_index in range(self.num_of_sampled_actions):
        self.children[Action(sampled_actions[action_index].detach().cpu().numpy())] = Node( ...

- Filling `policy_tmp` with 0 and then taking the softmax still leaves illegal moves with non-zero probability. I just set the default to -1e9 as a hack, and it seems to work now, but it's probably more efficient to extract the list of legal moves before taking the softmax.
- The line `policy_tmp[legal_action] = policy_logits[index]` is copying the wrong policy logits. I think it should be `policy_tmp[legal_action] = policy_logits[legal_action]`.
- I think instead of `for action_index in range(self.num_of_sampled_actions):`, it should be something like `for action_index in range(len(sampled_actions)):`. This is because there could be fewer legal moves than sampled actions (K), and `torch.multinomial` is used with `replacement=False`, so `sampled_actions[action_index]` will have an index out of bounds issue.
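
For what it's worth, here is a minimal sketch of how the three points could be handled together. It is a hypothetical standalone helper (the name `sample_legal_actions` is mine, not something in ptree_sez.py), and it assumes `policy_logits` has one entry per action in the full action space, which I have not confirmed against the rest of the code. It also assumes at least one legal action.

```python
import torch


def sample_legal_actions(policy_logits, legal_actions, num_of_sampled_actions):
    """Sample up to num_of_sampled_actions distinct legal actions.

    Hypothetical helper for illustration only. Assumes policy_logits is
    indexed by action id over the full action space.
    """
    logits = torch.as_tensor(policy_logits, dtype=torch.float32)
    legal = torch.as_tensor(legal_actions, dtype=torch.long)
    # Softmax over the legal logits only, so illegal actions get no mass at all.
    prob = torch.softmax(logits[legal], dim=-1)
    # Without replacement we cannot draw more samples than there are legal actions.
    k = min(num_of_sampled_actions, legal.numel())
    picked = torch.multinomial(prob, k, replacement=False)
    # picked holds positions within legal_actions; map them back to action ids.
    return legal[picked].tolist()
```

The caller would then loop over the returned list rather than over `range(self.num_of_sampled_actions)`, so the number of children matches the number of actions actually sampled.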
Sorry for not just submitting PRs for these bugs - I am not confident in the best way to fix them. I'm still working my way through my own fix, so I may run into more related bugs later. I haven't verified that the rest of MCTS works, and I haven't checked the "ctree" equivalent, either.