30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
class CheckpointManager:
    """Manages simulation checkpointing with support for MPI parallelisation.

    Provides both full checkpoints (complete state) and delta checkpoints
    (changes since last checkpoint) for efficiency.

    Two mutually exclusive scheduling modes are supported:
    - specific dates: checkpoint when the simulation reaches configured days
    - automatic interval: checkpoint every ``checkpoint_interval_days`` days
    """
    def __init__(self, simulator, checkpoint_interval_days: float = None, checkpoint_dates: list = None):
        """
        Initialise the checkpoint manager.

        Parameters:
            simulator (Simulator):
                The JUNE simulator instance
            checkpoint_interval_days (float):
                How often to create checkpoints (in simulation days) - used for automatic mode only
            checkpoint_dates (list):
                Specific simulation days/dates when checkpoints should be created.
                When provided, this takes precedence over interval-based mode.
        """
        if checkpoint_dates is not None:
            logger.info(f"Starting - MPI size: {mpi_size}, mode: specific_dates")
        else:
            logger.info(f"Starting - MPI size: {mpi_size}, interval: {checkpoint_interval_days} days")
        self.simulator = simulator
        self.checkpoint_interval = checkpoint_interval_days
        self.checkpoint_dates = checkpoint_dates
        # Dates already checkpointed (specific-dates mode) so each fires once.
        self.completed_checkpoint_dates = set()
        # Simulation time of the most recent checkpoint (interval mode).
        self.last_checkpoint_time = None
        # Chronological record of every checkpoint created during this run.
        self.checkpoint_history = []
        # Format version written into every checkpoint's metadata.
        self.checkpoint_version = "1.0"
        if self.checkpoint_dates is not None:
            logger.info(f"Configuration: checkpoint_version={self.checkpoint_version}, mode=specific_dates")
            logger.info(f"Checkpoint dates: {self.checkpoint_dates}")
            logger.info(f"Checkpoint mode: specific_dates - will checkpoint at simulation days: {self.checkpoint_dates}")
        else:
            logger.info(f"Configuration: checkpoint_version={self.checkpoint_version}, interval={checkpoint_interval_days}")
            logger.info(f"Checkpoint mode: automatic_interval - will checkpoint every {checkpoint_interval_days} days")
        # Initialise state serialisers
        self._init_serialisers()
        logger.info(f"Initialisation complete - ready for checkpointing")
def _init_serialisers(self):
"""Initialise all state serialization components"""
logger.info(f"Initialising state serialisers...")
self.population_serialiser = PopulationHealthSerialiser(self.simulator)
self.timer_serialiser = TimerStateSerialiser(self.simulator)
self.interaction_serialiser = InteractionStateSerialiser(self.simulator)
self.test_and_trace_serialiser = TestAndTraceSerialiser(self.simulator)
self.random_state_manager = RandomStateManager()
self.rat_dynamics_serialiser = RatDynamicsSerialiser(self.simulator)
self.tt_event_recorder_serialiser = TTEventRecorderSerialiser(self.simulator)
self.school_incident_serialiser = SchoolIncidentSerialiser(self.simulator)
logger.info(f"All serialisers successfully initialised")
def should_checkpoint(self) -> bool:
"""Determine if a checkpoint should be created now.
Returns:
(bool): True if a checkpoint should be created
"""
current_time = self.simulator.timer.now
if hasattr(self.simulator, '_is_resumed_and_first_round') and self.simulator._is_resumed_and_first_round:
self.simulator._is_resumed_and_first_round = False
return False
# Don't checkpoint at the very beginning (time 0)
if current_time == 0.0:
return False
# Check if we have specific checkpoint dates configured
if self.checkpoint_dates is not None:
if len(self.checkpoint_dates) == 0:
return False
# Check if current time matches any of the specified checkpoint dates
for checkpoint_date in self.checkpoint_dates:
if checkpoint_date not in self.completed_checkpoint_dates:
# Check if we've reached or passed this checkpoint date
if current_time >= checkpoint_date:
self.completed_checkpoint_dates.add(checkpoint_date)
return True
return False
# Only use interval-based checkpointing if checkpoint_dates is None (automatic mode)
if self.last_checkpoint_time is None:
# First checkpoint should happen after the interval
decision = current_time >= self.checkpoint_interval
return decision
time_since_last = current_time - self.last_checkpoint_time
decision = time_since_last >= self.checkpoint_interval
return decision
    def create_checkpoint(self, checkpoint_path: Path, checkpoint_type: str = "full") -> bool:
        """Create a comprehensive simulation checkpoint.

        Each MPI rank writes its own ``checkpoint_rank_<rank>.h5`` file; after
        an MPI barrier, rank 0 additionally writes the shared
        ``checkpoint_metadata.json``. Failures propagate as exceptions rather
        than a False return value.

        Args:
            checkpoint_path (Path): Directory where checkpoint files will be stored
            checkpoint_type (str, optional): Type of checkpoint: "full" or "delta" (Default value = "full")

        Returns:
            (bool): True if checkpoint was created successfully
        """
        current_time = self.simulator.timer.now
        logger.info(f"Starting checkpoint creation: type={checkpoint_type}, time={current_time}, path={checkpoint_path}")
        # Create checkpoint directory
        checkpoint_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Checkpoint directory created/verified: {checkpoint_path}")
        # Generate checkpoint metadata
        logger.info(f"Generating checkpoint metadata...")
        checkpoint_metadata = self._create_checkpoint_metadata(checkpoint_type)
        logger.info(f"Checkpoint metadata generated - version: {checkpoint_metadata.get('checkpoint_version')}")
        # Collect all state data (also saves RNG states into checkpoint_path)
        logger.info(f"Starting state data collection...")
        logger.info(f"Rank {mpi_rank}: Collecting simulation state data")
        checkpoint_data = self._collect_simulation_state(checkpoint_type, checkpoint_path)
        logger.info(f"State collection complete - {len(checkpoint_data)} components collected")
        # Save rank-specific data
        rank_file = checkpoint_path / f"checkpoint_rank_{mpi_rank}.h5"
        logger.info(f"Saving checkpoint data to: {rank_file}")
        self._save_checkpoint_data(checkpoint_data, rank_file)
        logger.info(f"Rank-specific data saved successfully")
        # Coordinate across MPI ranks: all rank files must exist before the
        # master metadata (which lists them) is written.
        if mpi_available:
            logger.info(f"Entering MPI barrier for rank coordination...")
            mpi_comm.Barrier()
            logger.info(f"MPI barrier complete - all ranks synchronised")
        # Master rank creates overall metadata
        if mpi_rank == 0:
            logger.info(f"Master rank creating overall metadata...")
            self._create_master_metadata(checkpoint_path, checkpoint_metadata)
            logger.info(f"Master metadata created")
        # Update checkpoint history (wall-clock timestamp, not sim time)
        history_entry = {
            'time': current_time,
            'path': str(checkpoint_path),
            'type': checkpoint_type,
            'timestamp': datetime.datetime.now().isoformat()
        }
        self.last_checkpoint_time = current_time
        self.checkpoint_history.append(history_entry)
        logger.info(f"Checkpoint history updated - total checkpoints: {len(self.checkpoint_history)}")
        logger.info(f"Checkpoint creation successful - type: {checkpoint_type}, time: {current_time}")
        logger.info(f"Rank {mpi_rank}: Checkpoint created successfully at {checkpoint_path}")
        return True
def _create_checkpoint_metadata(self, checkpoint_type: str) -> Dict[str, Any]:
"""Create checkpoint metadata
Args:
checkpoint_type (str):
"""
logger.info(f"Creating checkpoint metadata for type: {checkpoint_type}")
feature_flags = {
'test_and_trace_enabled': getattr(self.simulator, 'test_and_trace_enabled', False),
'ratty_dynamics_enabled': getattr(self.simulator, 'ratty_dynamics_enabled', False),
'friend_hangouts_enabled': getattr(self.simulator, 'friend_hangouts_enabled', False),
'sexual_encounter_enabled': getattr(self.simulator, 'sexual_encounter_enabled', False)
}
# Get disease name from epidemiology configuration
disease_name = None
if (hasattr(self.simulator, 'epidemiology') and
self.simulator.epidemiology and
hasattr(self.simulator.epidemiology, 'infection_selectors') and
self.simulator.epidemiology.infection_selectors._infection_selectors):
# Get disease name from the first infection selector
disease_name = self.simulator.epidemiology.infection_selectors._infection_selectors[0].disease_name
metadata = {
'checkpoint_version': self.checkpoint_version,
'checkpoint_type': checkpoint_type,
'simulation_time': self.simulator.timer.now,
'simulation_date': self.simulator.timer.date.isoformat(),
'total_simulation_days': self.simulator.timer.total_days,
'mpi_rank': mpi_rank,
'mpi_size': mpi_size,
'feature_flags': feature_flags,
'disease_name': disease_name,
'creation_timestamp': datetime.datetime.now().isoformat()
}
logger.info(f"Metadata created - sim_time: {metadata['simulation_time']}, disease: {disease_name}, flags: {feature_flags}")
return metadata
    def _collect_simulation_state(self, checkpoint_type: str, checkpoint_path) -> dict[str, Any]:
        """Collect all simulation state from various components.

        Each serialiser contributes one entry to the returned dict. RNG
        states are handled separately: they are captured here and written
        directly into ``checkpoint_path`` (not included in the return value).

        Args:
            checkpoint_type (str): Type of checkpoint being created ("full" or "delta")
            checkpoint_path: Directory the RNG states are saved into

        Returns:
            dict[str, Any]: Complete simulation state data keyed by component name

        Raises:
            RuntimeError: If saving the random number generator states fails.
        """
        # Collect all state data from various components
        state_data = {}
        logger.info(f"Starting simulation state collection for {checkpoint_type} checkpoint")
        # Core population health state (always required)
        logger.info(f"Serialising population health state...")
        state_data['population_health'] = self.population_serialiser.serialise(checkpoint_type)
        pop_size = len(state_data['population_health']) if isinstance(state_data['population_health'], (list, dict)) else 'N/A'
        logger.info(f"Population health serialised - data size: {pop_size}")
        # Timer state
        logger.info(f"Serialising timer state...")
        logger.debug(f"Rank {mpi_rank}: Serialising timer state")
        state_data['timer'] = self.timer_serialiser.serialise(checkpoint_type)
        timer_size = len(state_data['timer']) if isinstance(state_data['timer'], (list, dict)) else 'N/A'
        logger.info(f"Timer state serialised - data size: {timer_size}")
        # Interaction transmission tracking state
        logger.info(f"Serialising interaction transmission state...")
        logger.debug(f"Rank {mpi_rank}: Serialising interaction transmission state")
        state_data['interaction'] = self.interaction_serialiser.serialise(checkpoint_type)
        interaction_size = len(state_data['interaction']) if isinstance(state_data['interaction'], (list, dict)) else 'N/A'
        logger.info(f"Interaction state serialised - data size: {interaction_size}")
        # Test and trace state
        logger.info(f"Serialising test and trace state...")
        logger.debug(f"Rank {mpi_rank}: Serialising test and trace state")
        state_data['test_and_trace'] = self.test_and_trace_serialiser.serialise(checkpoint_type)
        tt_size = len(state_data['test_and_trace']) if isinstance(state_data['test_and_trace'], (list, dict)) else 'N/A'
        logger.info(f"Test and trace state serialised - data size: {tt_size}")
        # Rat dynamics state
        logger.info(f"Serialising rat dynamics state...")
        logger.debug(f"Rank {mpi_rank}: Serialising rat dynamics state")
        state_data['rat_dynamics'] = self.rat_dynamics_serialiser.serialise(checkpoint_type)
        rat_size = len(state_data['rat_dynamics']) if isinstance(state_data['rat_dynamics'], (list, dict)) else 'N/A'
        logger.info(f"Rat dynamics state serialised - data size: {rat_size}")
        # TTEventRecorder state (daily/cumulative test and trace data)
        logger.info(f"Serialising TTEventRecorder state...")
        logger.debug(f"Rank {mpi_rank}: Serialising TTEventRecorder state")
        state_data['tt_event_recorder'] = self.tt_event_recorder_serialiser.serialise(checkpoint_type)
        tte_size = len(state_data['tt_event_recorder']) if isinstance(state_data['tt_event_recorder'], (list, dict)) else 'N/A'
        logger.info(f"TTEventRecorder state serialised - data size: {tte_size}")
        # School incident tracking state (for NotSendingKidsToSchool policy)
        logger.info(f"Serialising school incident tracking state...")
        logger.debug(f"Rank {mpi_rank}: Serialising school incident tracking state")
        state_data['school_incidents'] = self.school_incident_serialiser.serialise(checkpoint_type)
        school_size = len(state_data['school_incidents']) if isinstance(state_data['school_incidents'], (list, dict)) else 'N/A'
        logger.info(f"School incident state serialised - data size: {school_size}")
        # Random number generator states - use dedicated manager
        logger.info(f"Capturing RNG states with dedicated manager...")
        random_states = self.random_state_manager.capture_states()
        logger.info(f"RNG states captured - count: {len(random_states) if isinstance(random_states, (list, dict)) else 'N/A'}")
        # Save random states separately - a failure here is fatal because a
        # checkpoint without RNG state cannot reproduce the run on restore.
        logger.info(f"Saving RNG states to checkpoint path...")
        if not self.random_state_manager.save_states(random_states, checkpoint_path):
            logger.error("CRITICAL ERROR: Failed to save random states")
            raise RuntimeError("Random state saving failed")
        logger.info(f"RNG states saved successfully")
        total_components = len(state_data)
        logger.info(f"State collection COMPLETE - {total_components} components collected")
        return state_data
    def _save_checkpoint_data(self, checkpoint_data: dict[str, Any], file_path: Path):
        """Save checkpoint data to HDF5 file.

        Writes a 'metadata' group with rank/version attributes, then one
        group per non-None component via ``_save_component_data``.

        Args:
            checkpoint_data (dict[str, Any]): The checkpoint data to save
            file_path (Path): Path to the HDF5 file (overwritten if present)
        """
        logger.info(f"Opening HDF5 file for writing: {file_path}")
        with h5py.File(file_path, 'w') as f:
            logger.info(f"HDF5 file opened successfully")
            # Save metadata
            logger.info(f"Creating metadata group...")
            metadata_group = f.create_group('metadata')
            metadata_group.attrs['checkpoint_version'] = self.checkpoint_version
            metadata_group.attrs['mpi_rank'] = mpi_rank
            metadata_group.attrs['mpi_size'] = mpi_size
            metadata_group.attrs['simulation_time'] = self.simulator.timer.now
            metadata_group.attrs['creation_timestamp'] = datetime.datetime.now().isoformat()
            logger.info(f"Metadata group created with 5 attributes")
            # Save each component's data (excluding random states, which are
            # written separately by the RandomStateManager)
            processed_components = 0
            for component_name, component_data in checkpoint_data.items():
                if component_data is not None:
                    logger.info(f"Processing component: {component_name}")
                    # Note: Removed convert_numpy_types to preserve numerical accuracy
                    logger.info(f"Processing component data for: {component_name}")
                    self._save_component_data(f, component_name, component_data)
                    processed_components += 1
                    logger.info(f"Component saved: {component_name}")
                else:
                    logger.info(f"Skipping component {component_name} - data is None")
            logger.info(f"HDF5 save complete - {processed_components} components processed")
    def _save_component_data(self, hdf5_file, component_name: str, data: Any):
        """Save a component's data to the HDF5 file.

        Dict values become datasets (arrays / JSON-string lists) or group
        attributes (scalars / JSON blobs) depending on their type; non-dict
        data is stored under a single 'data' dataset or attribute.

        Args:
            hdf5_file (h5py.File): Open HDF5 file handle
            component_name (str): Name of the component (becomes the group name)
            data (Any): Data to save
        """
        logger.info(f"Creating HDF5 group for component: {component_name}")
        group = hdf5_file.create_group(component_name)
        if isinstance(data, dict):
            logger.info(f"Processing dict data for {component_name} - {len(data)} keys")
            processed_keys = 0
            for key, value in data.items():
                logger.info(f"Processing key: {key} (type: {type(value).__name__})")
                if isinstance(value, (list, tuple)):
                    # Handle lists/tuples more carefully
                    if len(value) > 0:
                        logger.info(f"List/tuple data for {key} - length: {len(value)}")
                        # Check if all elements are the same type
                        first_type = type(value[0])
                        if all(isinstance(item, first_type) for item in value):
                            if isinstance(value[0], (int, float, bool)):
                                # Homogeneous numeric data - safe to convert to array
                                logger.info(f"Creating numeric dataset for {key}")
                                group.create_dataset(key, data=np.array(value))
                            else:
                                # Complex objects - convert to JSON strings
                                logger.info(f"Converting complex objects to JSON strings for {key}")
                                json_strings = [json.dumps(item, default=str) for item in value]
                                group.create_dataset(key, data=json_strings, dtype=h5py.string_dtype())
                        else:
                            # Heterogeneous data - convert to JSON strings
                            logger.info(f"Converting heterogeneous data to JSON strings for {key}")
                            json_strings = [json.dumps(item, default=str) for item in value]
                            group.create_dataset(key, data=json_strings, dtype=h5py.string_dtype())
                    else:
                        # Empty list - create empty dataset
                        logger.info(f"Creating empty dataset for {key}")
                        group.create_dataset(key, data=np.array([]), dtype='f')
                elif isinstance(value, np.ndarray):
                    # Already a numpy array
                    logger.info(f"Saving numpy array for {key} - shape: {value.shape}")
                    group.create_dataset(key, data=value)
                elif isinstance(value, (int, float, bool, str)):
                    # Simple scalar values
                    logger.info(f"Saving scalar value for {key}")
                    group.attrs[key] = value
                elif value is None:
                    # HDF5 attributes cannot hold None, so store a sentinel string
                    logger.info(f"Saving None value for {key}")
                    group.attrs[key] = "None"
                else:
                    # Convert complex objects to JSON
                    logger.info(f"Converting complex object to JSON for {key}")
                    group.attrs[key] = json.dumps(value, default=str)
                processed_keys += 1
            logger.info(f"Component {component_name} saved - {processed_keys} keys processed")
        else:
            # Handle non-dict data
            logger.info(f"Processing non-dict data for {component_name} (type: {type(data).__name__})")
            if isinstance(data, np.ndarray):
                logger.info(f"Saving numpy array data - shape: {data.shape}")
                group.create_dataset('data', data=data)
            else:
                logger.info(f"Converting data to JSON")
                group.attrs['data'] = json.dumps(data, default=str)
            logger.info(f"Non-dict component {component_name} saved")
    def _create_master_metadata(self, checkpoint_path: Path, metadata: dict[str, Any]):
        """Create master metadata file (only on rank 0)

        Writes ``checkpoint_metadata.json`` into the checkpoint directory,
        augmenting *metadata* with the rank count and rank file names so a
        restore can locate every rank's HDF5 file.

        Args:
            checkpoint_path (Path): Directory the metadata file is written into
            metadata (dict[str, Any]): Checkpoint metadata (mutated in place)
        """
        logger.info(f"Creating master metadata file...")
        metadata_file = checkpoint_path / "checkpoint_metadata.json"
        logger.info(f"Master metadata file path: {metadata_file}")
        # Add information about all ranks
        metadata['total_ranks'] = mpi_size
        metadata['rank_files'] = [f"checkpoint_rank_{rank}.h5" for rank in range(mpi_size)]
        logger.info(f"Added rank information - total_ranks: {mpi_size}, files: {len(metadata['rank_files'])}")
        logger.info(f"Writing master metadata to JSON file...")
        with open(metadata_file, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)
        logger.info(f"Master metadata file created successfully - {len(metadata)} keys written")
def list_available_checkpoints(self, checkpoint_dir: Path) -> list[dict[str, any]]:
"""List all available checkpoints in a directory.
Args:
checkpoint_dir (Path): Directory to search for checkpoints
Returns:
list[dict[str, any]]: List of checkpoint information
"""
checkpoints = []
if not checkpoint_dir.exists():
return checkpoints
for item in checkpoint_dir.iterdir():
if not item.is_dir() or item.name.startswith('.'):
continue
metadata_file = item / "checkpoint_metadata.json"
# Skip directories that don't have checkpoint metadata
if not metadata_file.exists():
continue
with open(metadata_file, 'r') as f:
metadata = json.load(f)
metadata['path'] = str(item)
checkpoints.append(metadata)
# Sort by simulation time
checkpoints.sort(key=lambda x: x.get('simulation_time', 0))
return checkpoints
def get_latest_checkpoint(self, checkpoint_dir: Path) -> Optional[dict[str, any]]:
"""Get the most recent checkpoint.
Args:
checkpoint_dir (Path): Directory to search for checkpoints
Returns:
Optional[Dict[str, Any]]: Latest checkpoint metadata, or None if no checkpoints found
"""
checkpoints = self.list_available_checkpoints(checkpoint_dir)
return checkpoints[-1] if checkpoints else None
def reset_completed_checkpoint_dates_after_restoration(self, restored_simulation_time: float):
"""Reset completed checkpoint dates after restoration to allow future checkpoints.
After restoring from a checkpoint, we need to reset the tracking of completed
checkpoint dates so that any dates after the restored time can still be used
for creating new checkpoints.
Args:
restored_simulation_time (float): The simulation time that was restored from the checkpoint
"""
if self.checkpoint_dates is None:
# Not using specific dates mode, nothing to reset
return
# FORCE CLEAR ALL COMPLETED DATES FOR CHILD RUNS
# This ensures that child runs can create checkpoints at any configured date
original_completed = self.completed_checkpoint_dates.copy()
self.completed_checkpoint_dates = set() # Clear everything
# Log what was reset
if original_completed:
logger.info(f"FORCED RESET: Cleared all completed checkpoint dates after restoration from time {restored_simulation_time}")
logger.info(f" - Cleared dates: {sorted(original_completed)}")
logger.info(f" - All configured dates now available: {sorted(self.checkpoint_dates)}")
else:
logger.info(f"No checkpoint dates needed resetting after restoration from time {restored_simulation_time}")
# Also reset last checkpoint time if it's after the restored time
if self.last_checkpoint_time and self.last_checkpoint_time > restored_simulation_time:
logger.info(f"Reset last_checkpoint_time from {self.last_checkpoint_time} to None after restoration")
self.last_checkpoint_time = None
|