Skip to content

Checkpoint manager

Core Checkpoint Manager for JUNE Simulations

Orchestrates the creation and management of simulation checkpoints, coordinating between different state serialisers and handling MPI distribution.

CheckpointManager

Manages simulation checkpointing with support for MPI parallelisation.

Provides both full checkpoints (complete state) and delta checkpoints (changes since last checkpoint) for efficiency.

Source code in june/checkpointing/checkpoint_manager.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
class CheckpointManager:
    """Manages simulation checkpointing with support for MPI parallelisation.

    Provides both full checkpoints (complete state) and delta checkpoints
    (changes since last checkpoint) for efficiency.

    Two scheduling modes are supported (mutually exclusive):
      * specific_dates  - checkpoint when the simulation reaches any of the
        configured ``checkpoint_dates``;
      * automatic_interval - checkpoint every ``checkpoint_interval_days``
        simulation days (used only when ``checkpoint_dates`` is None).
    """

    def __init__(self, simulator, checkpoint_interval_days: float = None, checkpoint_dates: list = None):
        """
        Initialise the checkpoint manager.

        Parameters:
          simulator (Simulator):
            The JUNE simulator instance
          checkpoint_interval_days (float):
            How often to create checkpoints (in simulation days) - used for automatic mode only
          checkpoint_dates (list):
            Specific simulation days/dates when checkpoints should be created
        """
        if checkpoint_dates is not None:
            logger.info(f"Starting - MPI size: {mpi_size}, mode: specific_dates")
        else:
            logger.info(f"Starting - MPI size: {mpi_size}, interval: {checkpoint_interval_days} days")

        self.simulator = simulator
        # Interval in simulation days; only consulted when checkpoint_dates is None.
        self.checkpoint_interval = checkpoint_interval_days
        # Explicit list of simulation days to checkpoint at; None selects interval mode.
        self.checkpoint_dates = checkpoint_dates
        # Dates already checkpointed in specific_dates mode (mutated by should_checkpoint).
        self.completed_checkpoint_dates = set()
        # Simulation time of the most recent successful checkpoint (None before the first).
        self.last_checkpoint_time = None
        # Chronological record of every checkpoint created by this manager.
        self.checkpoint_history = []
        # Format version written into every checkpoint's metadata.
        self.checkpoint_version = "1.0"

        if self.checkpoint_dates is not None:
            logger.info(f"Configuration: checkpoint_version={self.checkpoint_version}, mode=specific_dates")
            logger.info(f"Checkpoint dates: {self.checkpoint_dates}")
            logger.info(f"Checkpoint mode: specific_dates - will checkpoint at simulation days: {self.checkpoint_dates}")
        else:
            logger.info(f"Configuration: checkpoint_version={self.checkpoint_version}, interval={checkpoint_interval_days}")
            logger.info(f"Checkpoint mode: automatic_interval - will checkpoint every {checkpoint_interval_days} days")

        # Initialise state serialisers
        self._init_serialisers()

        logger.info(f"Initialisation complete - ready for checkpointing")

    def _init_serialisers(self):
        """Initialise all state serialization components.

        One serialiser per state domain (population health, timer, interaction,
        test-and-trace, RNG, rat dynamics, TT event recorder, school incidents);
        each is later queried by ``_collect_simulation_state``.
        """
        logger.info(f"Initialising state serialisers...")

        self.population_serialiser = PopulationHealthSerialiser(self.simulator)
        self.timer_serialiser = TimerStateSerialiser(self.simulator)
        self.interaction_serialiser = InteractionStateSerialiser(self.simulator)
        self.test_and_trace_serialiser = TestAndTraceSerialiser(self.simulator)
        self.random_state_manager = RandomStateManager()
        self.rat_dynamics_serialiser = RatDynamicsSerialiser(self.simulator)
        self.tt_event_recorder_serialiser = TTEventRecorderSerialiser(self.simulator)
        self.school_incident_serialiser = SchoolIncidentSerialiser(self.simulator)

        logger.info(f"All serialisers successfully initialised")

    def should_checkpoint(self) -> bool:
        """Determine if a checkpoint should be created now.

        NOTE: in specific_dates mode this is not a pure query - when a due date
        is found it is immediately added to ``completed_checkpoint_dates``, so a
        True result should be followed by an actual checkpoint creation.

        Returns:
            (bool): True if a checkpoint should be created

        """
        current_time = self.simulator.timer.now
        # Skip the first call after a resume so we don't immediately re-checkpoint
        # at the restored time (flag is consumed here).
        if hasattr(self.simulator, '_is_resumed_and_first_round') and self.simulator._is_resumed_and_first_round:
            self.simulator._is_resumed_and_first_round = False
            return False

        # Don't checkpoint at the very beginning (time 0)
        if current_time == 0.0:
            return False

        # Check if we have specific checkpoint dates configured
        if self.checkpoint_dates is not None:
            if len(self.checkpoint_dates) == 0:
                return False

            # Check if current time matches any of the specified checkpoint dates
            for checkpoint_date in self.checkpoint_dates:
                if checkpoint_date not in self.completed_checkpoint_dates:
                    # Check if we've reached or passed this checkpoint date
                    if current_time >= checkpoint_date:
                        # Mark as done now; any other overdue dates stay pending
                        # and will trigger on subsequent calls.
                        self.completed_checkpoint_dates.add(checkpoint_date)
                        return True

            return False

        # Only use interval-based checkpointing if checkpoint_dates is None (automatic mode)
        if self.last_checkpoint_time is None:
            # First checkpoint should happen after the interval
            decision = current_time >= self.checkpoint_interval
            return decision

        time_since_last = current_time - self.last_checkpoint_time
        decision = time_since_last >= self.checkpoint_interval
        return decision

    def create_checkpoint(self, checkpoint_path: Path, checkpoint_type: str = "full") -> bool:
        """Create a comprehensive simulation checkpoint.

        Each MPI rank writes its own ``checkpoint_rank_<rank>.h5`` file; after a
        barrier, rank 0 additionally writes ``checkpoint_metadata.json``.
        Updates ``last_checkpoint_time`` and ``checkpoint_history`` on success.

        Args:
            checkpoint_path (Path): Directory where checkpoint files will be stored
            checkpoint_type (str, optional): Type of checkpoint: "full" or "delta" (Default value = "full")

        Returns:
            (bool): True if checkpoint was created successfully

        """

        current_time = self.simulator.timer.now
        logger.info(f"Starting checkpoint creation: type={checkpoint_type}, time={current_time}, path={checkpoint_path}")

        # Create checkpoint directory
        checkpoint_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Checkpoint directory created/verified: {checkpoint_path}")

        # Generate checkpoint metadata
        logger.info(f"Generating checkpoint metadata...")
        checkpoint_metadata = self._create_checkpoint_metadata(checkpoint_type)
        logger.info(f"Checkpoint metadata generated - version: {checkpoint_metadata.get('checkpoint_version')}")

        # Collect all state data
        logger.info(f"Starting state data collection...")
        logger.info(f"Rank {mpi_rank}: Collecting simulation state data")
        checkpoint_data = self._collect_simulation_state(checkpoint_type, checkpoint_path)
        logger.info(f"State collection complete - {len(checkpoint_data)} components collected")

        # Save rank-specific data
        rank_file = checkpoint_path / f"checkpoint_rank_{mpi_rank}.h5"
        logger.info(f"Saving checkpoint data to: {rank_file}")
        self._save_checkpoint_data(checkpoint_data, rank_file)
        logger.info(f"Rank-specific data saved successfully")

        # Coordinate across MPI ranks: all rank files must exist before rank 0
        # writes the master metadata that references them.
        if mpi_available:
            logger.info(f"Entering MPI barrier for rank coordination...")
            mpi_comm.Barrier()
            logger.info(f"MPI barrier complete - all ranks synchronised")

        # Master rank creates overall metadata
        if mpi_rank == 0:
            logger.info(f"Master rank creating overall metadata...")
            self._create_master_metadata(checkpoint_path, checkpoint_metadata)
            logger.info(f"Master metadata created")

        # Update checkpoint history
        history_entry = {
            'time': current_time,
            'path': str(checkpoint_path),
            'type': checkpoint_type,
            'timestamp': datetime.datetime.now().isoformat()
        }
        self.last_checkpoint_time = current_time
        self.checkpoint_history.append(history_entry)
        logger.info(f"Checkpoint history updated - total checkpoints: {len(self.checkpoint_history)}")

        logger.info(f"Checkpoint creation successful - type: {checkpoint_type}, time: {current_time}")
        logger.info(f"Rank {mpi_rank}: Checkpoint created successfully at {checkpoint_path}")
        return True

    def _create_checkpoint_metadata(self, checkpoint_type: str) -> Dict[str, Any]:
        """Create checkpoint metadata.

        Args:
            checkpoint_type (str): Type of checkpoint being created ("full" or "delta")

        Returns:
            Dict[str, Any]: Metadata describing version, simulation time/date,
            MPI layout, feature flags, disease name and creation timestamp.
        """
        logger.info(f"Creating checkpoint metadata for type: {checkpoint_type}")

        # Feature flags default to False when the simulator doesn't define them.
        feature_flags = {
            'test_and_trace_enabled': getattr(self.simulator, 'test_and_trace_enabled', False),
            'ratty_dynamics_enabled': getattr(self.simulator, 'ratty_dynamics_enabled', False),
            'friend_hangouts_enabled': getattr(self.simulator, 'friend_hangouts_enabled', False),
            'sexual_encounter_enabled': getattr(self.simulator, 'sexual_encounter_enabled', False)
        }

        # Get disease name from epidemiology configuration
        disease_name = None
        if (hasattr(self.simulator, 'epidemiology') and 
            self.simulator.epidemiology and 
            hasattr(self.simulator.epidemiology, 'infection_selectors') and
            self.simulator.epidemiology.infection_selectors._infection_selectors):
            # Get disease name from the first infection selector
            disease_name = self.simulator.epidemiology.infection_selectors._infection_selectors[0].disease_name

        metadata = {
            'checkpoint_version': self.checkpoint_version,
            'checkpoint_type': checkpoint_type,
            'simulation_time': self.simulator.timer.now,
            'simulation_date': self.simulator.timer.date.isoformat(),
            'total_simulation_days': self.simulator.timer.total_days,
            'mpi_rank': mpi_rank,
            'mpi_size': mpi_size,
            'feature_flags': feature_flags,
            'disease_name': disease_name,
            'creation_timestamp': datetime.datetime.now().isoformat()
        }

        logger.info(f"Metadata created - sim_time: {metadata['simulation_time']}, disease: {disease_name}, flags: {feature_flags}")
        return metadata

    def _collect_simulation_state(self, checkpoint_type: str, checkpoint_path: Path) -> dict[str, Any]:
        """Collect all simulation state from various components.

        RNG states are captured here but saved directly to ``checkpoint_path``
        by the RandomStateManager rather than being returned in the state dict.

        Args:
            checkpoint_type (str): Type of checkpoint being created
            checkpoint_path (Path): Checkpoint directory (used for the separate RNG state files)

        Returns:
            dict[str, Any]: Complete simulation state data

        Raises:
            RuntimeError: If saving the random number generator states fails.
        """
        # Collect all state data from various components
        state_data = {}
        logger.info(f"Starting simulation state collection for {checkpoint_type} checkpoint")

        # Core population health state (always required)
        logger.info(f"Serialising population health state...")
        state_data['population_health'] = self.population_serialiser.serialise(checkpoint_type)
        pop_size = len(state_data['population_health']) if isinstance(state_data['population_health'], (list, dict)) else 'N/A'
        logger.info(f"Population health serialised - data size: {pop_size}")

        # Timer state
        logger.info(f"Serialising timer state...")
        logger.debug(f"Rank {mpi_rank}: Serialising timer state")
        state_data['timer'] = self.timer_serialiser.serialise(checkpoint_type)
        timer_size = len(state_data['timer']) if isinstance(state_data['timer'], (list, dict)) else 'N/A'
        logger.info(f"Timer state serialised - data size: {timer_size}")

        # Interaction transmission tracking state
        logger.info(f"Serialising interaction transmission state...")
        logger.debug(f"Rank {mpi_rank}: Serialising interaction transmission state")
        state_data['interaction'] = self.interaction_serialiser.serialise(checkpoint_type)
        interaction_size = len(state_data['interaction']) if isinstance(state_data['interaction'], (list, dict)) else 'N/A'
        logger.info(f"Interaction state serialised - data size: {interaction_size}")

        # Test and trace state
        logger.info(f"Serialising test and trace state...")
        logger.debug(f"Rank {mpi_rank}: Serialising test and trace state")
        state_data['test_and_trace'] = self.test_and_trace_serialiser.serialise(checkpoint_type)
        tt_size = len(state_data['test_and_trace']) if isinstance(state_data['test_and_trace'], (list, dict)) else 'N/A'
        logger.info(f"Test and trace state serialised - data size: {tt_size}")

        # Rat dynamics state
        logger.info(f"Serialising rat dynamics state...")
        logger.debug(f"Rank {mpi_rank}: Serialising rat dynamics state")
        state_data['rat_dynamics'] = self.rat_dynamics_serialiser.serialise(checkpoint_type)
        rat_size = len(state_data['rat_dynamics']) if isinstance(state_data['rat_dynamics'], (list, dict)) else 'N/A'
        logger.info(f"Rat dynamics state serialised - data size: {rat_size}")

        # TTEventRecorder state (daily/cumulative test and trace data)
        logger.info(f"Serialising TTEventRecorder state...")
        logger.debug(f"Rank {mpi_rank}: Serialising TTEventRecorder state")
        state_data['tt_event_recorder'] = self.tt_event_recorder_serialiser.serialise(checkpoint_type)
        tte_size = len(state_data['tt_event_recorder']) if isinstance(state_data['tt_event_recorder'], (list, dict)) else 'N/A'
        logger.info(f"TTEventRecorder state serialised - data size: {tte_size}")

        # School incident tracking state (for NotSendingKidsToSchool policy)
        logger.info(f"Serialising school incident tracking state...")
        logger.debug(f"Rank {mpi_rank}: Serialising school incident tracking state")
        state_data['school_incidents'] = self.school_incident_serialiser.serialise(checkpoint_type)
        school_size = len(state_data['school_incidents']) if isinstance(state_data['school_incidents'], (list, dict)) else 'N/A'
        logger.info(f"School incident state serialised - data size: {school_size}")

        # Random number generator states - use dedicated manager
        logger.info(f"Capturing RNG states with dedicated manager...")
        random_states = self.random_state_manager.capture_states()
        logger.info(f"RNG states captured - count: {len(random_states) if isinstance(random_states, (list, dict)) else 'N/A'}")

        # Save random states separately (not included in the returned state dict,
        # so they are excluded from the HDF5 rank file).
        logger.info(f"Saving RNG states to checkpoint path...")
        if not self.random_state_manager.save_states(random_states, checkpoint_path):
            logger.error("CRITICAL ERROR: Failed to save random states")
            raise RuntimeError("Random state saving failed")
        logger.info(f"RNG states saved successfully")

        total_components = len(state_data)
        logger.info(f"State collection COMPLETE - {total_components} components collected")
        return state_data

    def _save_checkpoint_data(self, checkpoint_data: dict[str, Any], file_path: Path):
        """Save checkpoint data to HDF5 file.

        Writes a ``metadata`` group plus one group per non-None component
        (via ``_save_component_data``). Components whose data is None are skipped.

        Args:
            checkpoint_data (dict[str, Any]): The checkpoint data to save
            file_path (Path): Path to the HDF5 file

        """
        logger.info(f"Opening HDF5 file for writing: {file_path}")

        with h5py.File(file_path, 'w') as f:
            logger.info(f"HDF5 file opened successfully")

            # Save metadata
            logger.info(f"Creating metadata group...")
            metadata_group = f.create_group('metadata')
            metadata_group.attrs['checkpoint_version'] = self.checkpoint_version
            metadata_group.attrs['mpi_rank'] = mpi_rank
            metadata_group.attrs['mpi_size'] = mpi_size
            metadata_group.attrs['simulation_time'] = self.simulator.timer.now
            metadata_group.attrs['creation_timestamp'] = datetime.datetime.now().isoformat()
            logger.info(f"Metadata group created with 5 attributes")

            # Save each component's data (excluding random states)
            processed_components = 0
            for component_name, component_data in checkpoint_data.items():
                if component_data is not None:
                    logger.info(f"Processing component: {component_name}")
                    # Note: Removed convert_numpy_types to preserve numerical accuracy
                    logger.info(f"Processing component data for: {component_name}")
                    self._save_component_data(f, component_name, component_data)
                    processed_components += 1
                    logger.info(f"Component saved: {component_name}")
                else:
                    logger.info(f"Skipping component {component_name} - data is None")

            logger.info(f"HDF5 save complete - {processed_components} components processed")

    def _save_component_data(self, hdf5_file, component_name: str, data: Any):
        """Save a component's data to the HDF5 file.

        Mapping rules for dict values: homogeneous numeric lists/tuples become
        numeric datasets; other lists/tuples become JSON-string datasets; numpy
        arrays become datasets; scalars and None become group attributes; any
        other object is JSON-encoded into a group attribute.

        Args:
            hdf5_file (h5py.File): Open HDF5 file handle
            component_name (str): Name of the component
            data (Any): Data to save

        """
        logger.info(f"Creating HDF5 group for component: {component_name}")
        group = hdf5_file.create_group(component_name)

        if isinstance(data, dict):
            logger.info(f"Processing dict data for {component_name} - {len(data)} keys")
            processed_keys = 0

            for key, value in data.items():
                logger.info(f"Processing key: {key} (type: {type(value).__name__})")

                if isinstance(value, (list, tuple)):
                    # Handle lists/tuples more carefully
                    if len(value) > 0:
                        logger.info(f"List/tuple data for {key} - length: {len(value)}")
                        # Check if all elements are the same type
                        first_type = type(value[0])
                        if all(isinstance(item, first_type) for item in value):
                            if isinstance(value[0], (int, float, bool)):
                                # Homogeneous numeric data - safe to convert to array
                                logger.info(f"Creating numeric dataset for {key}")
                                group.create_dataset(key, data=np.array(value))
                            else:
                                # Complex objects - convert to JSON strings
                                logger.info(f"Converting complex objects to JSON strings for {key}")
                                json_strings = [json.dumps(item, default=str) for item in value]
                                group.create_dataset(key, data=json_strings, dtype=h5py.string_dtype())
                        else:
                            # Heterogeneous data - convert to JSON strings
                            logger.info(f"Converting heterogeneous data to JSON strings for {key}")
                            json_strings = [json.dumps(item, default=str) for item in value]
                            group.create_dataset(key, data=json_strings, dtype=h5py.string_dtype())
                    else:
                        # Empty list - create empty dataset
                        logger.info(f"Creating empty dataset for {key}")
                        group.create_dataset(key, data=np.array([]), dtype='f')
                elif isinstance(value, np.ndarray):
                    # Already a numpy array
                    logger.info(f"Saving numpy array for {key} - shape: {value.shape}")
                    group.create_dataset(key, data=value)
                elif isinstance(value, (int, float, bool, str)):
                    # Simple scalar values
                    logger.info(f"Saving scalar value for {key}")
                    group.attrs[key] = value
                elif value is None:
                    # HDF5 attributes cannot store None, so use a sentinel string.
                    logger.info(f"Saving None value for {key}")
                    group.attrs[key] = "None"
                else:
                    # Convert complex objects to JSON
                    logger.info(f"Converting complex object to JSON for {key}")
                    group.attrs[key] = json.dumps(value, default=str)

                processed_keys += 1

            logger.info(f"Component {component_name} saved - {processed_keys} keys processed")

        else:
            # Handle non-dict data: store under a single 'data' dataset/attribute.
            logger.info(f"Processing non-dict data for {component_name} (type: {type(data).__name__})")

            if isinstance(data, np.ndarray):
                logger.info(f"Saving numpy array data - shape: {data.shape}")
                group.create_dataset('data', data=data)
            else:
                logger.info(f"Converting data to JSON")
                group.attrs['data'] = json.dumps(data, default=str)

            logger.info(f"Non-dict component {component_name} saved")


    def _create_master_metadata(self, checkpoint_path: Path, metadata: dict[str, Any]):
        """Create master metadata file (only on rank 0).

        Augments the per-rank metadata with the total rank count and the list
        of expected rank file names, then writes ``checkpoint_metadata.json``.

        Args:
            checkpoint_path (Path): Checkpoint directory to write the JSON file into
            metadata (dict[str, Any]): Checkpoint metadata (mutated in place with rank info)

        """
        logger.info(f"Creating master metadata file...")

        metadata_file = checkpoint_path / "checkpoint_metadata.json"
        logger.info(f"Master metadata file path: {metadata_file}")

        # Add information about all ranks
        metadata['total_ranks'] = mpi_size
        metadata['rank_files'] = [f"checkpoint_rank_{rank}.h5" for rank in range(mpi_size)]
        logger.info(f"Added rank information - total_ranks: {mpi_size}, files: {len(metadata['rank_files'])}")

        logger.info(f"Writing master metadata to JSON file...")
        with open(metadata_file, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

        logger.info(f"Master metadata file created successfully - {len(metadata)} keys written")

    def list_available_checkpoints(self, checkpoint_dir: Path) -> list[dict[str, Any]]:
        """List all available checkpoints in a directory.

        A checkpoint is any non-hidden subdirectory containing a
        ``checkpoint_metadata.json`` file.

        Args:
            checkpoint_dir (Path): Directory to search for checkpoints

        Returns:
            list[dict[str, Any]]: List of checkpoint information, sorted by
            simulation time; each entry carries a 'path' key with the directory.

        """
        checkpoints = []

        if not checkpoint_dir.exists():
            return checkpoints

        for item in checkpoint_dir.iterdir():
            if not item.is_dir() or item.name.startswith('.'):
                continue

            metadata_file = item / "checkpoint_metadata.json"

            # Skip directories that don't have checkpoint metadata
            if not metadata_file.exists():
                continue

            with open(metadata_file, 'r') as f:
                metadata = json.load(f)
            metadata['path'] = str(item)
            checkpoints.append(metadata)

        # Sort by simulation time
        checkpoints.sort(key=lambda x: x.get('simulation_time', 0))
        return checkpoints

    def get_latest_checkpoint(self, checkpoint_dir: Path) -> Optional[dict[str, Any]]:
        """Get the most recent checkpoint.

        Args:
            checkpoint_dir (Path): Directory to search for checkpoints

        Returns:
            Optional[Dict[str, Any]]: Latest checkpoint metadata, or None if no checkpoints found

        """
        checkpoints = self.list_available_checkpoints(checkpoint_dir)
        return checkpoints[-1] if checkpoints else None

    def reset_completed_checkpoint_dates_after_restoration(self, restored_simulation_time: float):
        """Reset completed checkpoint dates after restoration to allow future checkpoints.

        After restoring from a checkpoint, we need to reset the tracking of completed
        checkpoint dates so that any dates after the restored time can still be used
        for creating new checkpoints.

        Args:
            restored_simulation_time (float): The simulation time that was restored from the checkpoint

        """
        if self.checkpoint_dates is None:
            # Not using specific dates mode, nothing to reset
            return

        # FORCE CLEAR ALL COMPLETED DATES FOR CHILD RUNS
        # This ensures that child runs can create checkpoints at any configured date
        original_completed = self.completed_checkpoint_dates.copy()
        self.completed_checkpoint_dates = set()  # Clear everything

        # Log what was reset
        if original_completed:
            logger.info(f"FORCED RESET: Cleared all completed checkpoint dates after restoration from time {restored_simulation_time}")
            logger.info(f"  - Cleared dates: {sorted(original_completed)}")
            logger.info(f"  - All configured dates now available: {sorted(self.checkpoint_dates)}")
        else:
            logger.info(f"No checkpoint dates needed resetting after restoration from time {restored_simulation_time}")

        # Also reset last checkpoint time if it's after the restored time
        if self.last_checkpoint_time and self.last_checkpoint_time > restored_simulation_time:
            logger.info(f"Reset last_checkpoint_time from {self.last_checkpoint_time} to None after restoration")
            self.last_checkpoint_time = None

__init__(simulator, checkpoint_interval_days=None, checkpoint_dates=None)

Initialise the checkpoint manager.

Parameters:

Name Type Description Default
simulator Simulator

The JUNE simulator instance

required
checkpoint_interval_days float

How often to create checkpoints (in simulation days) - used for automatic mode only

None
checkpoint_dates list

Specific simulation days/dates when checkpoints should be created

None
Source code in june/checkpointing/checkpoint_manager.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def __init__(self, simulator, checkpoint_interval_days: float = None, checkpoint_dates: list = None):
    """
    Initialise the checkpoint manager.

    Parameters:
      simulator (Simulator):
        The JUNE simulator instance
      checkpoint_interval_days (float):
        How often to create checkpoints (in simulation days) - used for automatic mode only
      checkpoint_dates (list):
        Specific simulation days/dates when checkpoints should be created
    """
    if checkpoint_dates is not None:
        logger.info(f"Starting - MPI size: {mpi_size}, mode: specific_dates")
    else:
        logger.info(f"Starting - MPI size: {mpi_size}, interval: {checkpoint_interval_days} days")

    self.simulator = simulator
    self.checkpoint_interval = checkpoint_interval_days
    self.checkpoint_dates = checkpoint_dates
    self.completed_checkpoint_dates = set()
    self.last_checkpoint_time = None
    self.checkpoint_history = []
    self.checkpoint_version = "1.0"

    if self.checkpoint_dates is not None:
        logger.info(f"Configuration: checkpoint_version={self.checkpoint_version}, mode=specific_dates")
        logger.info(f"Checkpoint dates: {self.checkpoint_dates}")
        logger.info(f"Checkpoint mode: specific_dates - will checkpoint at simulation days: {self.checkpoint_dates}")
    else:
        logger.info(f"Configuration: checkpoint_version={self.checkpoint_version}, interval={checkpoint_interval_days}")
        logger.info(f"Checkpoint mode: automatic_interval - will checkpoint every {checkpoint_interval_days} days")

    # Initialise state serialisers
    self._init_serialisers()

    logger.info(f"Initialisation complete - ready for checkpointing")

create_checkpoint(checkpoint_path, checkpoint_type='full')

Create a comprehensive simulation checkpoint.

Parameters:

Name Type Description Default
checkpoint_path Path

Directory where checkpoint files will be stored

required
checkpoint_type str

Type of checkpoint: "full" or "delta" (Default value = "full")

'full'

Returns:

Type Description
bool

True if checkpoint was created successfully

Source code in june/checkpointing/checkpoint_manager.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def create_checkpoint(self, checkpoint_path: Path, checkpoint_type: str = "full") -> bool:
    """Create a comprehensive simulation checkpoint.

    Orchestrates the full sequence: directory creation, metadata generation,
    per-rank state collection, per-rank HDF5 serialisation, MPI
    synchronisation, master (rank 0) metadata, and history bookkeeping.

    Args:
        checkpoint_path (Path): Directory where checkpoint files will be stored
        checkpoint_type (str, optional): Type of checkpoint: "full" or "delta" (Default value = "full")

    Returns:
        (bool): True if checkpoint was created successfully

    NOTE(review): this method never returns False — any failure surfaces as
    an exception raised by one of the helpers; confirm callers expect that
    rather than checking the boolean.
    """

    current_time = self.simulator.timer.now
    logger.info(f"Starting checkpoint creation: type={checkpoint_type}, time={current_time}, path={checkpoint_path}")

    # Create checkpoint directory
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Checkpoint directory created/verified: {checkpoint_path}")

    # Generate checkpoint metadata
    logger.info(f"Generating checkpoint metadata...")
    checkpoint_metadata = self._create_checkpoint_metadata(checkpoint_type)
    logger.info(f"Checkpoint metadata generated - version: {checkpoint_metadata.get('checkpoint_version')}")

    # Collect all state data
    logger.info(f"Starting state data collection...")
    logger.info(f"Rank {mpi_rank}: Collecting simulation state data")
    checkpoint_data = self._collect_simulation_state(checkpoint_type, checkpoint_path)
    logger.info(f"State collection complete - {len(checkpoint_data)} components collected")

    # Save rank-specific data: each MPI rank writes its own HDF5 file.
    rank_file = checkpoint_path / f"checkpoint_rank_{mpi_rank}.h5"
    logger.info(f"Saving checkpoint data to: {rank_file}")
    self._save_checkpoint_data(checkpoint_data, rank_file)
    logger.info(f"Rank-specific data saved successfully")

    # Coordinate across MPI ranks: the barrier guarantees every rank's
    # checkpoint_rank_*.h5 exists before rank 0 writes the global metadata.
    if mpi_available:
        logger.info(f"Entering MPI barrier for rank coordination...")
        mpi_comm.Barrier()
        logger.info(f"MPI barrier complete - all ranks synchronised")

    # Master rank creates overall metadata (the file that
    # list_available_checkpoints later uses for discovery).
    if mpi_rank == 0:
        logger.info(f"Master rank creating overall metadata...")
        self._create_master_metadata(checkpoint_path, checkpoint_metadata)
        logger.info(f"Master metadata created")

    # Update checkpoint history; last_checkpoint_time also drives the
    # interval-based decision in should_checkpoint().
    history_entry = {
        'time': current_time,
        'path': str(checkpoint_path),
        'type': checkpoint_type,
        'timestamp': datetime.datetime.now().isoformat()
    }
    self.last_checkpoint_time = current_time
    self.checkpoint_history.append(history_entry)
    logger.info(f"Checkpoint history updated - total checkpoints: {len(self.checkpoint_history)}")

    logger.info(f"Checkpoint creation successful - type: {checkpoint_type}, time: {current_time}")
    logger.info(f"Rank {mpi_rank}: Checkpoint created successfully at {checkpoint_path}")
    return True

get_latest_checkpoint(checkpoint_dir)

Get the most recent checkpoint.

Parameters:

Name Type Description Default
checkpoint_dir Path

Directory to search for checkpoints

required

Returns:

Type Description
Optional[dict[str, any]]

Optional[Dict[str, Any]]: Latest checkpoint metadata, or None if no checkpoints found

Source code in june/checkpointing/checkpoint_manager.py
490
491
492
493
494
495
496
497
498
499
500
501
def get_latest_checkpoint(self, checkpoint_dir: Path) -> Optional[dict[str, any]]:
    """Return metadata for the most recent checkpoint.

    Args:
        checkpoint_dir (Path): Directory to search for checkpoints

    Returns:
        Optional[Dict[str, Any]]: Latest checkpoint metadata, or None if no checkpoints found

    """
    available = self.list_available_checkpoints(checkpoint_dir)
    if not available:
        return None
    # list_available_checkpoints sorts by simulation time, so the last
    # entry is the most recent checkpoint.
    return available[-1]

list_available_checkpoints(checkpoint_dir)

List all available checkpoints in a directory.

Parameters:

Name Type Description Default
checkpoint_dir Path

Directory to search for checkpoints

required

Returns:

Type Description
list[dict[str, any]]

list[dict[str, any]]: List of checkpoint metadata dictionaries, sorted by simulation time (oldest first); each entry carries an added 'path' key pointing at the checkpoint directory

Source code in june/checkpointing/checkpoint_manager.py
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
def list_available_checkpoints(self, checkpoint_dir: Path) -> list[dict[str, any]]:
    """List all available checkpoints in a directory.

    Scans the immediate subdirectories of ``checkpoint_dir`` for a
    ``checkpoint_metadata.json`` file. Hidden directories, directories
    without metadata, and directories whose metadata cannot be parsed are
    skipped, so one corrupt checkpoint cannot break discovery of the rest.

    Args:
        checkpoint_dir (Path): Directory to search for checkpoints

    Returns:
        list[dict[str, any]]: List of checkpoint information, sorted by
            simulation time (oldest first); each entry is the parsed
            metadata with an added 'path' key.

    """
    checkpoints = []

    if not checkpoint_dir.exists():
        return checkpoints

    for item in checkpoint_dir.iterdir():
        # Only directories can be checkpoints; ignore files and hidden dirs.
        if not item.is_dir() or item.name.startswith('.'):
            continue

        metadata_file = item / "checkpoint_metadata.json"

        # Skip directories that don't have checkpoint metadata
        if not metadata_file.exists():
            continue

        # Best-effort read: a truncated/corrupt metadata file (e.g. from a
        # job killed mid-write) is skipped instead of aborting the listing.
        try:
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)
        except (OSError, json.JSONDecodeError):
            continue

        metadata['path'] = str(item)
        checkpoints.append(metadata)

    # Sort by simulation time so the most recent checkpoint is last.
    checkpoints.sort(key=lambda x: x.get('simulation_time', 0))
    return checkpoints

reset_completed_checkpoint_dates_after_restoration(restored_simulation_time)

Reset completed checkpoint dates after restoration to allow future checkpoints.

After restoring from a checkpoint, we need to reset the tracking of completed checkpoint dates so that any dates after the restored time can still be used for creating new checkpoints.

Parameters:

Name Type Description Default
restored_simulation_time float

The simulation time that was restored from the checkpoint

required
Source code in june/checkpointing/checkpoint_manager.py
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
def reset_completed_checkpoint_dates_after_restoration(self, restored_simulation_time: float):
    """Reset completed checkpoint dates after restoration to allow future checkpoints.

    After restoring from a checkpoint, the bookkeeping of already-created
    checkpoint dates is cleared wholesale so a resumed (child) run can
    create new checkpoints at any of the configured dates.

    Args:
        restored_simulation_time (float): The simulation time that was restored from the checkpoint

    """
    # Interval mode keeps no per-date bookkeeping, so there is nothing to reset.
    if self.checkpoint_dates is None:
        return

    # Clear *every* completed date (not only those after the restored time):
    # child runs must be free to checkpoint at any configured date.
    original_completed = set(self.completed_checkpoint_dates)
    self.completed_checkpoint_dates = set()

    if not original_completed:
        logger.info(f"No checkpoint dates needed resetting after restoration from time {restored_simulation_time}")
    else:
        logger.info(f"FORCED RESET: Cleared all completed checkpoint dates after restoration from time {restored_simulation_time}")
        logger.info(f"  - Cleared dates: {sorted(original_completed)}")
        logger.info(f"  - All configured dates now available: {sorted(self.checkpoint_dates)}")

    # A last-checkpoint timestamp lying in the restored run's future is stale.
    if self.last_checkpoint_time and self.last_checkpoint_time > restored_simulation_time:
        logger.info(f"Reset last_checkpoint_time from {self.last_checkpoint_time} to None after restoration")
        self.last_checkpoint_time = None

should_checkpoint()

Determine if a checkpoint should be created now.

Returns:

Type Description
bool

True if a checkpoint should be created

Source code in june/checkpointing/checkpoint_manager.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def should_checkpoint(self) -> bool:
    """Determine if a checkpoint should be created now.

    In specific-dates mode this also records the first due date as
    completed, so each configured date fires at most once.

    Returns:
        (bool): True if a checkpoint should be created

    """
    now = self.simulator.timer.now

    # One-shot guard: skip the first evaluation right after a resume.
    if getattr(self.simulator, '_is_resumed_and_first_round', False):
        self.simulator._is_resumed_and_first_round = False
        return False

    # Never checkpoint at the very start of the simulation.
    if now == 0.0:
        return False

    if self.checkpoint_dates is not None:
        # Specific-dates mode: fire on the first configured date that has
        # been reached and not yet checkpointed, marking it as done.
        for due_date in self.checkpoint_dates:
            if due_date in self.completed_checkpoint_dates:
                continue
            if now >= due_date:
                self.completed_checkpoint_dates.add(due_date)
                return True
        return False

    # Automatic interval mode: measure elapsed time since the last
    # checkpoint, treating simulation start as the baseline for the first.
    baseline = 0.0 if self.last_checkpoint_time is None else self.last_checkpoint_time
    return now - baseline >= self.checkpoint_interval