Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse jaxified logp when sampling via jax #7681

Merged
merged 1 commit into from
Feb 26, 2025

Conversation

nataziel
Copy link
Contributor

@nataziel nataziel commented Feb 14, 2025

reuse jaxified logp when sampling via jax

Description

#7610 added logic to handle passing a pre-jaxified logp function into the blackjax/numpyro samplers, but missed actually passing the jaxified logp that is computed in sample_jax_nuts

Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks)
  • If you are a pro: each commit corresponds to a [relevant logical change]

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7681.org.readthedocs.build/en/7681/

@nataziel nataziel changed the title reuse jaxified logp times when sampling via jax reuse jaxified logp when sampling via jax Feb 14, 2025
Copy link

codecov bot commented Feb 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.64%. Comparing base (358b825) to head (e63a8a2).
Report is 7 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7681      +/-   ##
==========================================
- Coverage   92.70%   92.64%   -0.06%     
==========================================
  Files         107      107              
  Lines       18391    18324      -67     
==========================================
- Hits        17050    16977      -73     
- Misses       1341     1347       +6     
Files with missing lines Coverage Δ
pymc/sampling/jax.py 94.11% <ø> (-0.91%) ⬇️

... and 1 file with indirect coverage changes

@nataziel
Copy link
Contributor Author

@ricardoV94 not sure if you've seen this, but it's a super tiny change that we should have included with #7610 that I just missed

@twiecki
Copy link
Member

twiecki commented Feb 26, 2025

Claude Code:
Thanks for this PR! I see this is a small but important maintenance fix that adds the parameter to the sampler function call, which was computed earlier but not passed through. This ensures we reuse the already jaxified logp function when sampling via JAX instead of recomputing it unnecessarily.\n\nThis change looks good and should improve efficiency by avoiding duplicate computation of the logp function. Appears to be a small oversight in PR #7610 that this properly fixes.

@ricardoV94
Copy link
Member

Claude Code: Thanks for this PR! I see this is a small but important maintenance fix that adds the parameter to the sampler function call, which was computed earlier but not passed through. This ensures we reuse the already jaxified logp function when sampling via JAX instead of recomputing it unnecessarily.\n\nThis change looks good and should improve efficiency by avoiding duplicate computation of the logp function. Appears to be a small oversight in PR #7610 that this properly fixes.

Useless

@ricardoV94
Copy link
Member

Thanks @nataziel

@ricardoV94 ricardoV94 merged commit 2fbc8a9 into pymc-devs:main Feb 26, 2025
26 checks passed
@ricardoV94 ricardoV94 changed the title reuse jaxified logp when sampling via jax Reuse jaxified logp when sampling via jax Feb 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants