Allow for pymc native samplers to resume sampling from ZarrTrace
#7687
base: main
        vars=trace_vars,
        test_point=initial_point,
    )
except TraceAlreadyInitialized:
Maybe just `InitializedTrace`? Seems a little verbose!
Sounds fine to me, it's an internal thing
if isinstance(trace, ZarrChain):
    progress_manager.set_initial_state(*trace.completed_draws_and_divergences())
    progress_manager._progress.update(
        progress_manager.tasks[i],
        draws=progress_manager.completed_draws
        if progress_manager.combined_progress
        else progress_manager.draws,
        divergences=progress_manager.divergences,
        refresh=True,
I still don't like this abstraction leaking elsewhere. Just provide a default in the Ndarray backend that makes it work for either method. In that case I suppose everything starts at zero.
if isinstance(trace, ZarrChain):
    trace.link_stepper(step)
    stored_draw_idx = trace._sampling_state.draw_idx[chain]
Same here: all this logic, including the old `link_stepper`, can have a sensible default in the base trace class so you don't need to worry about what kind of trace you have here. Just make `link_stepper` a no-op and `stored_draw_idx` zero by default?
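A minimal sketch of the defaults suggested in the two comments above. `BaseTraceSketch`, `ResumableTraceSketch`, and `start_chain` are hypothetical names used only for illustration and are not PyMC's actual classes; only the method names `link_stepper` and `completed_draws_and_divergences` are taken from the diff.

```python
class BaseTraceSketch:
    """Hypothetical base backend: defaults describe a fresh, non-resumable trace."""

    stored_draw_idx = 0  # nothing has been stored yet

    def link_stepper(self, step):
        """No-op by default; resumable backends override this."""

    def completed_draws_and_divergences(self):
        # A fresh trace has no completed draws and no divergences.
        return 0, 0


class ResumableTraceSketch(BaseTraceSketch):
    """Hypothetical backend that remembers progress from an earlier run."""

    def __init__(self, draw_idx, divergences):
        self.stored_draw_idx = draw_idx
        self._divergences = divergences
        self.stepper = None

    def link_stepper(self, step):
        self.stepper = step

    def completed_draws_and_divergences(self):
        return self.stored_draw_idx, self._divergences


def start_chain(trace, step):
    """Caller code without isinstance checks: every backend answers the same calls."""
    trace.link_stepper(step)
    draws, divergences = trace.completed_draws_and_divergences()
    return trace.stored_draw_idx, draws, divergences
```

With this shape, the sampling loop can treat every backend the same way, and a backend that cannot resume simply reports that it starts from zero.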
if stored_draw_idx > 0:
    if stored_sampling_state is not None:
        self._step_method.sampling_state = stored_sampling_state
    else:
        raise RuntimeError(
            "Cannot use the supplied ZarrTrace to restart sampling because "
            "it has no sampling_state information stored. You will have to "
            "resample from scratch."
        )
    draw = stored_draw_idx
    self._write_point(trace.get_mcmc_point())
Duplicated logic, should be a property of the backend object?
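One way to read this suggestion is to move the restore check into a single backend method that every sampler calls. The sketch below assumes hypothetical names (`TraceSketch`, `StepSketch`, `restore`) and a simplified state layout; it is not the PR's implementation.

```python
class TraceSketch:
    """Hypothetical backend holding what a previous run stored."""

    def __init__(self, draw_idx=0, sampling_state=None, last_point=None):
        self.draw_idx = draw_idx
        self.sampling_state = sampling_state
        self.last_point = last_point

    def restore(self, step_method):
        """Restore `step_method` and return (first draw index, point to resume from)."""
        if self.draw_idx == 0:
            return 0, None  # fresh trace: nothing to restore
        if self.sampling_state is None:
            raise RuntimeError(
                "Cannot restart sampling: the trace has no stored sampling_state."
            )
        step_method.sampling_state = self.sampling_state
        return self.draw_idx, self.last_point


class StepSketch:
    """Hypothetical step method with a settable sampling_state attribute."""

    sampling_state = None
```

Each sampling code path would then call `trace.restore(step)` once instead of repeating the same branch.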
@@ -491,6 +509,10 @@ def __init__(
            progressbar=progressbar,
            progressbar_theme=progressbar_theme,
        )
        if self.zarr_recording:
abstraction leaking
I like the new functionality, but I am deeply against all the `if isinstance(..., ZarrTrace)` checks in the codebase. Either our code is supposed to allow different trace backends or it is not. This suggests you want to drop the Ndarray altogether, which is fine if you do.
Otherwise, all these cases seem like they could be handled by the BaseTrace having sensible defaults for these methods. We used to have continuation of traces in the past with Ndarray, and I don't see anything that fundamentally needs ZarrTrace other than dev interest in it. So just make them raise `NotImplementedError` or make them no-ops, and adjust the external code appropriately.
I stopped halfway, so this was not an extensive review. I think this is a bigger design point that needs a decision before settling on the details of the PR.
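The `NotImplementedError` option mentioned in this comment could look roughly like the sketch below. All names (`BackendSketch`, `InMemoryBackendSketch`, `PersistentBackendSketch`, `load_sampling_state`, `initial_state`) are hypothetical and chosen for illustration only.

```python
class BackendSketch:
    """Hypothetical base backend: resuming is unsupported unless overridden."""

    def load_sampling_state(self):
        raise NotImplementedError(
            f"{type(self).__name__} cannot resume sampling"
        )


class InMemoryBackendSketch(BackendSketch):
    """Inherits the default, so a resume request fails loudly."""


class PersistentBackendSketch(BackendSketch):
    """Backend that kept the sampler state from a previous run."""

    def __init__(self, state):
        self._state = state

    def load_sampling_state(self):
        return self._state


def initial_state(trace, resume):
    """External code asks the backend directly instead of checking its type."""
    if not resume:
        return None
    return trace.load_sampling_state()
```

Compared with no-op defaults, this variant surfaces an explicit error when a user asks to resume from a backend that cannot support it.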
Description
Big PR approaching! This finishes adding the ability of pymc native step methods to resume sampling from an existing trace (as long as it's a `ZarrTrace`!). This means that you can now continue tuning or sampling from a pre-existing sample run.
Another thing is that the `chunks_per_draw` from `ZarrTrace`, along with its persistent storage backends (like `ZipStore` or `DirectoryStore`), makes the sampler store the results and the final sampling state periodically. In case of a crash during sampling, you can load the trace from the existing store using `ZarrTrace.from_store` and then resume sampling from there.
The only thing that I haven't tested yet is adding an `Op` that makes `pm.sample` crash, to see if I can reload the partial results from the store and resume sampling. @ricardoV94 gave me some pointers on that, but I won't be working on this for the rest of the month, so I thought it best to open a draft PR to kick off any discussion you have or collect feedback.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7687.org.readthedocs.build/en/7687/
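The crash-recovery workflow outlined in the description (periodic checkpoints, then resuming from the store) can be illustrated with a toy example. This is a sketch under stated assumptions: a plain `dict` stands in for a persistent store, `run_chain` is a hypothetical random-walk "sampler", and none of it uses PyMC's or Zarr's actual API.

```python
import random


def run_chain(store, total_draws, checkpoint_every, crash_at=None):
    """Toy random-walk 'sampler' that checkpoints into a dict-like `store`."""
    rng = random.Random(0)
    state = store.get("state")
    if state is not None:
        # Restore the generator exactly where the last checkpoint left off.
        rng.setstate(state["rng"])
    draws = list(store.get("draws", []))
    x = draws[-1] if draws else 0.0
    for i in range(len(draws), total_draws):
        if i == crash_at:
            raise RuntimeError("simulated crash")
        x += rng.gauss(0.0, 1.0)
        draws.append(x)
        if (i + 1) % checkpoint_every == 0:
            # Persist draws and sampler state together so they stay consistent.
            store["draws"] = list(draws)
            store["state"] = {"rng": rng.getstate()}
    # Final write so the store always reflects a completed run.
    store["draws"] = list(draws)
    store["state"] = {"rng": rng.getstate()}
    return draws


# Usage: crash partway through, then resume from the last checkpoint.
store = {}
try:
    run_chain(store, total_draws=20, checkpoint_every=5, crash_at=7)
except RuntimeError:
    pass
resumed = run_chain(store, total_draws=20, checkpoint_every=5)
```

Draws made after the last checkpoint and before the crash are lost and regenerated on resume, which is the trade-off a larger checkpoint interval implies.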