
Allow for pymc native samplers to resume sampling from ZarrTrace #7687

Draft · wants to merge 12 commits into base: main

Conversation

@lucianopaz (Contributor) commented Feb 21, 2025

Description

Big PR approaching! This finishes adding the ability for pymc native step methods to resume sampling from an existing trace (as long as it's a ZarrTrace!). This means that you can now continue tuning or sampling from a pre-existing sampling run. For example:

with model:
    # First tuning run
    pm.sample(tune=400, draws=0, trace=trace)

    # Do whatever to decide if you want to continue tuning   
    pm.sample(tune=800, draws=0, trace=trace)

    # Switch to sampling
    pm.sample(tune=800, draws=1000, trace=trace)

Another thing: the chunks_per_draw from ZarrTrace, together with its persistent storage backends (like ZipStore or DirectoryStore), makes sampling store the results and the final sampling state periodically. So in case of a crash during sampling, you can load the trace from the existing store using ZarrTrace.from_store and then resume sampling from there.
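The checkpoint-and-resume idea can be illustrated with a small self-contained sketch. This is plain Python, deliberately independent of PyMC's actual API: the `sample` function, the JSON state file, and `chunk_size` are all made up here to mirror how ZarrTrace periodically persists results and sampling state so a crashed run can be continued.

```python
import json
import os


def sample(n_draws, state_path, chunk_size=100):
    """Toy sampler that checkpoints its full state every chunk_size
    draws, so a later call with the same state_path resumes where the
    previous run stopped (or crashed)."""
    # Resume from a previous checkpoint if one exists.
    if os.path.exists(state_path):
        with open(state_path) as f:
            state = json.load(f)
    else:
        state = {"draw_idx": 0, "draws": []}

    while state["draw_idx"] < n_draws:
        # A toy "draw": deterministic, so a resumed run produces
        # exactly the same values a fresh uninterrupted run would.
        state["draws"].append(state["draw_idx"] * 0.5)
        state["draw_idx"] += 1
        # Persist partial results at every chunk boundary, analogous
        # to ZarrTrace writing to its persistent store per chunk.
        if state["draw_idx"] % chunk_size == 0:
            with open(state_path, "w") as f:
                json.dump(state, f)

    # Final checkpoint so a follow-up run can extend this one.
    with open(state_path, "w") as f:
        json.dump(state, f)
    return state["draws"]
```

A first run of 120 draws followed by a second call asking for 250 draws will reuse the stored state and only compute the missing 130 draws.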

The only thing I haven't tested yet is adding an Op that makes pm.sample crash, to check whether I can reload the partial results from the store and resume sampling. @ricardoV94 gave me some pointers on that, but I won't be working on this for the rest of the month, so I thought it best to open a draft PR to kick off any discussion and collect feedback.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7687.org.readthedocs.build/en/7687/

@lucianopaz lucianopaz added enhancements trace-backend Traces and ArviZ stuff major Include in major changes release notes section labels Feb 21, 2025
@lucianopaz lucianopaz changed the title Zarr continue Allow for pymc native samplers to resume sampling from ZarrTrace Feb 21, 2025
vars=trace_vars,
test_point=initial_point,
)
except TraceAlreadyInitialized:

Maybe just InitializedTrace? Seems a little verbose!

Sounds fine to me, it's an internal thing

Comment on lines +1161 to +1169
if isinstance(trace, ZarrChain):
progress_manager.set_initial_state(*trace.completed_draws_and_divergences())
progress_manager._progress.update(
progress_manager.tasks[i],
draws=progress_manager.completed_draws
if progress_manager.combined_progress
else progress_manager.draws,
divergences=progress_manager.divergences,
refresh=True,

I still don't like this abstraction leaking elsewhere; just provide a default to the Ndarray backend that makes it work for either method. In that case I suppose everything starts at zero.

if isinstance(trace, ZarrChain):
trace.link_stepper(step)
stored_draw_idx = trace._sampling_state.draw_idx[chain]

Same here: all this logic, including the old link_stepper, can have a sensible default in the base trace class, so you don't need to worry about what kind of trace you have here. Just make link_stepper a no-op and have stored_draw_idx default to zero?
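The suggested default could look something like this sketch. The class and method names follow the review comment, but the surrounding BaseTrace API is assumed here, not taken from the PR; `ZarrChainSketch` is a stand-in for the real ZarrChain.

```python
class BaseTrace:
    """Sketch of the suggestion: give the base class sensible no-op
    defaults so callers never need isinstance(..., ZarrChain) checks."""

    def link_stepper(self, step_method):
        # No-op by default; only backends that persist sampler state
        # need to remember the step method.
        pass

    @property
    def stored_draw_idx(self):
        # Backends without persisted state always start from draw 0.
        return 0


class ZarrChainSketch(BaseTrace):
    """Only the persistent backend overrides the defaults."""

    def __init__(self, draw_idx):
        self._draw_idx = draw_idx
        self.stepper = None

    def link_stepper(self, step_method):
        self.stepper = step_method

    @property
    def stored_draw_idx(self):
        return self._draw_idx


def resume_point(trace):
    # Caller code now works uniformly for any backend, with no
    # isinstance branching: plain traces resume from 0.
    trace.link_stepper("step")
    return trace.stored_draw_idx
```

With these defaults, the sampling loop can always call `link_stepper` and read `stored_draw_idx`, and non-persistent backends simply behave as if sampling starts from scratch.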

Comment on lines +201 to +211
if stored_draw_idx > 0:
if stored_sampling_state is not None:
self._step_method.sampling_state = stored_sampling_state
else:
raise RuntimeError(
"Cannot use the supplied ZarrTrace to restart sampling because "
"it has no sampling_state information stored. You will have to "
"resample from scratch."
)
draw = stored_draw_idx
self._write_point(trace.get_mcmc_point())

Duplicated logic, should be a property of the backend object?

@@ -491,6 +509,10 @@ def __init__(
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)
if self.zarr_recording:

abstraction leaking

@ricardoV94 (Member) left a comment


I like the new functionality, but I am deeply against all the if isinstance(..., ZarrTrace) checks in the codebase. Either our code is supposed to allow different trace backends or it is not; this suggests you want to drop the Ndarray backend altogether, which is fine if you do.

Otherwise, all these cases seem like they could be handled by BaseTrace having sensible defaults for these methods. We used to have trace continuation with Ndarray in the past, and I don't see anything that fundamentally needs ZarrTrace other than dev interest in it. So just make them raise NotImplementedError or make them no-ops, and adjust the external code appropriately.

I stopped half-way, so this was not an extensive review. I think this is a bigger design point that needs a decision before settling on the details of the PR.

Labels
enhancements feature request major Include in major changes release notes section request discussion samplers trace-backend Traces and ArviZ stuff
Development

Successfully merging this pull request may close these issues.

ENH: Add checkpoints during sampling
3 participants