
mean_kl is always 0 #10

Open
xzhang2523 opened this issue Feb 21, 2023 · 5 comments


@xzhang2523

Hi, I notice that in your code, mean_kl is always 0:

    constraint_grad = flat_grad(constraint_loss, self.policy.parameters(), retain_graph=True)  # (b)

    mean_kl = mean_kl_first_fixed(actions_dists, actions_dists)
    Fvp_fun = get_Hvp_fun(mean_kl, self.policy.parameters())

What is the meaning of the gradient of a constant?

@ajlangley
Owner

How do you know that it is always 0? I have not looked at this code in a long time, but I'll do my best to help.

@xzhang2523
Author

[Screenshot of the update code, with the importance-sampling ratio and the mean_kl computation at line 164 highlighted]

In line 164, the KL divergence between two identical distributions is 0.

And the importance-sampling ratio is always 1.

Best,
Xiaoyuan

@ajlangley
Owner

The first line you highlighted is not a bug. If you look closely, the second term is detached, so we are taking the gradient of the importance-sampling ratio with respect to the new policy. The second line you boxed is indeed a bug: one of those terms should be detached (whichever one represents the old policy). I believe the first input should be detached and the second one should require a gradient.
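As a minimal sketch of the detach pattern described above (illustrative names and values, not the repository's code): the ratio below evaluates to 1 everywhere, yet it still carries a gradient with respect to the new policy's log-probabilities.

```python
import torch

# Log-probabilities under the new policy (the parameters being updated).
log_probs_new = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)

# Importance-sampling ratio pi_new / pi_old with the "old" term detached.
# Numerically this is exp(0) = 1, but it is not a constant in the autograd graph.
imp_sampling = torch.exp(log_probs_new - log_probs_new.detach())
print(imp_sampling)        # tensor([1., 1., 1.], grad_fn=...)

advantages = torch.tensor([0.5, -0.2, 1.0])
surrogate = (imp_sampling * advantages).mean()
surrogate.backward()
print(log_probs_new.grad)  # advantages / 3 -- a nonzero policy gradient
```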

@xzhang2523
Author

Dear aj,

Actually, I printed the value of "imp_sampling" while debugging, and it is always 1.

@ajlangley
Owner

ajlangley commented Feb 23, 2023

Right, but does it have `requires_grad=True`?

The intention is that it will always be 1 when you print it out. What is really happening is that we are setting up a symbolic expression involving the old policy, which we just used to collect the data, and the new policy, which we are going to find via gradient descent, using the old policy as the starting point.

What I mean is, we are really setting up a function J(x) and taking the gradient at x = old_policy.

Maybe you can try printing out the sampling ratio between the old policy and the new policy after the update. If it is still 1, that is indeed a bug.

Does this make sense?
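To make this concrete, here is a rough sketch (illustrative, not the repository's code) of why a printed mean_kl of 0 is expected in this setup: the KL between the current distribution and a detached copy of itself is 0 in value, but its curvature with respect to the new parameters is the Fisher information, which is exactly what the Hessian-vector-product function is there to provide.

```python
import torch
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence

logits = torch.tensor([0.2, -0.5, 1.0], requires_grad=True)
dist_new = Categorical(logits=logits)
dist_old = Categorical(logits=logits.detach())  # "old" policy: same values, no grad

mean_kl = kl_divergence(dist_old, dist_new).mean()
print(mean_kl)  # tensor(0., grad_fn=...) -- zero in value, as observed

# Hessian-vector product of the KL (a Fisher-vector product): nonzero.
grad_kl = torch.autograd.grad(mean_kl, logits, create_graph=True)[0]
v = torch.tensor([1.0, -0.5, 0.25])
hvp = torch.autograd.grad(grad_kl @ v, logits)[0]
print(hvp)      # nonzero, even though mean_kl itself prints as 0
```

At the start of an update the ratio is 1 and the KL is 0 by construction; it is their gradients (and the KL's curvature) that drive the step.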
