Skip to content

Random state manager

Dedicated Random State Manager for JUNE Checkpointing

This module provides exact random state preservation and restoration, ensuring bit-for-bit reproducibility across checkpoint/restore cycles.

RandomStateManager

Manages random number generator states for exact reproducibility.

Uses dedicated pickle files to preserve exact random state formats, avoiding the data corruption issues of HDF5 type conversion.

Source code in june/checkpointing/random_state_manager.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class RandomStateManager:
    """Manages random number generator states for exact reproducibility.

    Handles the global generator of the ``random`` module and NumPy's global
    legacy generator (``np.random.get_state`` / ``np.random.set_state``).

    Uses dedicated pickle files to preserve exact random state formats,
    avoiding the data corruption issues of HDF5 type conversion. Each MPI
    rank reads and writes its own file,
    ``<checkpoint_dir>/random_states/random_states_rank_<rank>.pkl``.

    Note:
        States are read back with ``pickle.load``, which can execute
        arbitrary code. Only load checkpoints from trusted locations.

    """

    def __init__(self):
        """Initialise the random state manager."""
        self.supported_generators = ['python_random', 'numpy_random']

    def capture_states(self) -> Dict[str, Any]:
        """Capture current random number generator states.

        Each generator is queried exactly once; the metadata is derived from
        the captured tuples rather than from a second query.

        Returns:
            Dict[str, Any]: Dictionary with the keys ``'python_random'``
                (result of ``random.getstate()``), ``'numpy_random'``
                (result of ``np.random.get_state()``) and ``'metadata'``.

        """
        logger.debug(f"Rank {mpi_rank}: Capturing random states")

        # Capture Python random state
        python_state = random.getstate()
        logger.debug(f"Rank {mpi_rank}: Captured Python random state (type: {type(python_state)})")

        # Capture NumPy random state
        numpy_state = np.random.get_state()
        logger.debug(f"Rank {mpi_rank}: Captured NumPy random state (type: {type(numpy_state)})")

        states = {
            'python_random': python_state,
            'numpy_random': numpy_state,
            'metadata': {
                'mpi_rank': mpi_rank,
                # First element of the Python state tuple: the state-format
                # version used by the random module (not the interpreter version).
                'python_version': python_state[0],
                # First element of the NumPy state tuple: the generator name.
                'numpy_generator': numpy_state[0],
                'capture_success': True,
            },
        }

        logger.info(f"Rank {mpi_rank}: Successfully captured all random states")
        return states

    def save_states(self, states: Dict[str, Any], checkpoint_dir: Path) -> bool:
        """Save random states to dedicated pickle files.

        Args:
            states (Dict[str, Any]): Random states to save
            checkpoint_dir (Path): Directory to save states in. Missing
                directories are created; a plain string path is accepted.

        Returns:
            bool: True if successful

        Raises:
            ValueError: If the file read back after writing is not a
                dictionary containing ``'python_random'``.
            OSError: If the directory or file cannot be written or read.

        """
        # Create random states subdirectory. parents=True so that a checkpoint
        # directory that does not exist yet is not an error; exist_ok=True
        # because several ranks may create the same directory.
        random_states_dir = Path(checkpoint_dir) / "random_states"
        random_states_dir.mkdir(parents=True, exist_ok=True)

        # Save each rank's states to separate file
        rank_file = random_states_dir / f"random_states_rank_{mpi_rank}.pkl"

        with open(rank_file, 'wb') as f:
            pickle.dump(states, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info(f"Rank {mpi_rank}: Saved random states to {rank_file}")

        # Verify the save by attempting to load
        with open(rank_file, 'rb') as f:
            test_load = pickle.load(f)

        # Basic verification
        if not isinstance(test_load, dict) or 'python_random' not in test_load:
            raise ValueError("Verification failed: saved data is corrupted")

        logger.debug(f"Rank {mpi_rank}: Random state save verification passed")
        return True

    def load_states(self, checkpoint_dir: Path) -> Optional[Dict[str, Any]]:
        """Load random states from pickle files.

        Args:
            checkpoint_dir (Path): Directory containing saved states; a plain
                string path is accepted.

        Returns:
            Optional[Dict[str, Any]]: Loaded random states, or None if the
                file is missing, cannot be read or unpickled, or fails
                validation.

        """
        random_states_dir = Path(checkpoint_dir) / "random_states"
        rank_file = random_states_dir / f"random_states_rank_{mpi_rank}.pkl"

        if not rank_file.exists():
            logger.error(f"Rank {mpi_rank}: Random states file not found: {rank_file}")
            return None

        # SECURITY: pickle.load can execute arbitrary code; the checkpoint
        # directory must come from a trusted source.
        try:
            with open(rank_file, 'rb') as f:
                states = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError, OSError) as exc:
            # A truncated or corrupted file is a load failure, which this
            # method reports by returning None rather than by raising.
            logger.error(f"Rank {mpi_rank}: Failed to read random states from {rank_file}: {exc}")
            return None

        logger.info(f"Rank {mpi_rank}: Loaded random states from {rank_file}")

        # Validate loaded data
        if not self._validate_states(states):
            logger.error(f"Rank {mpi_rank}: Loaded random states failed validation")
            return None

        return states

    def restore_states(self, states: Dict[str, Any]) -> bool:
        """Restore random number generator states.

        Both required entries are checked before either generator is
        modified, so a False return leaves the generators untouched.

        Args:
            states (Dict[str, Any]): Random states to restore

        Returns:
            bool: True if successful, False if a required state is missing

        """
        # Check everything first so a missing entry cannot cause a
        # half-applied restore.
        if 'python_random' not in states:
            logger.error(f"Rank {mpi_rank}: No Python random state found")
            return False
        if 'numpy_random' not in states:
            logger.error(f"Rank {mpi_rank}: No NumPy random state found")
            return False

        # Restore Python random state
        random.setstate(states['python_random'])
        logger.debug(f"Rank {mpi_rank}: Restored Python random state")

        # Restore NumPy random state
        np.random.set_state(states['numpy_random'])
        logger.debug(f"Rank {mpi_rank}: Restored NumPy random state")

        logger.info(f"Rank {mpi_rank}: Successfully restored all random states")
        return True

    def _validate_states(self, states: Dict[str, Any]) -> bool:
        """Validate loaded random states.

        Checks that the payload is a dictionary with the required keys, that
        the Python state is a 3-tuple, that the NumPy state is a 5-tuple and
        that the metadata is a dictionary containing ``'mpi_rank'``.

        Args:
            states (Dict[str, Any]): States to validate

        Returns:
            bool: True if valid

        """
        # An unpickled file can contain any object; reject non-dictionaries
        # up front instead of failing on the membership tests below.
        if not isinstance(states, dict):
            logger.error(f"Rank {mpi_rank}: Random states payload is not a dictionary")
            return False

        # Check required keys
        required_keys = ['python_random', 'numpy_random', 'metadata']
        for key in required_keys:
            if key not in states:
                logger.error(f"Rank {mpi_rank}: Missing required key: {key}")
                return False

        # Validate Python random state format
        python_state = states['python_random']
        if not isinstance(python_state, tuple) or len(python_state) != 3:
            logger.error(f"Rank {mpi_rank}: Invalid Python random state format")
            return False

        # Validate NumPy random state format
        numpy_state = states['numpy_random']
        if not isinstance(numpy_state, tuple) or len(numpy_state) != 5:
            logger.error(f"Rank {mpi_rank}: Invalid NumPy random state format")
            return False

        # Validate metadata
        metadata = states['metadata']
        if not isinstance(metadata, dict) or 'mpi_rank' not in metadata:
            logger.error(f"Rank {mpi_rank}: Invalid metadata format")
            return False

        logger.debug(f"Rank {mpi_rank}: Random state validation passed")
        return True

__init__()

Initialise the random state manager.

Source code in june/checkpointing/random_state_manager.py
28
29
30
def __init__(self):
    """Set up the manager with the generator names it handles."""
    # Keys under which the Python and NumPy states are stored.
    generator_names = ['python_random', 'numpy_random']
    self.supported_generators = generator_names

capture_states()

Capture current random number generator states.

Returns:

Type Description
Dict[str, Any]

Dictionary containing all captured random states

Source code in june/checkpointing/random_state_manager.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def capture_states(self) -> Dict[str, Any]:
    """Capture current random number generator states.

    Each generator is queried exactly once; the metadata is derived from
    the captured tuples rather than from a second query.

    Returns:
        Dict[str, Any]: Dictionary with the keys ``'python_random'``
            (result of ``random.getstate()``), ``'numpy_random'``
            (result of ``np.random.get_state()``) and ``'metadata'``.

    """
    logger.debug(f"Rank {mpi_rank}: Capturing random states")

    # Capture Python random state
    python_state = random.getstate()
    logger.debug(f"Rank {mpi_rank}: Captured Python random state (type: {type(python_state)})")

    # Capture NumPy random state
    numpy_state = np.random.get_state()
    logger.debug(f"Rank {mpi_rank}: Captured NumPy random state (type: {type(numpy_state)})")

    states = {
        'python_random': python_state,
        'numpy_random': numpy_state,
        'metadata': {
            'mpi_rank': mpi_rank,
            # First element of the Python state tuple: the state-format
            # version used by the random module (not the interpreter version).
            'python_version': python_state[0],
            # First element of the NumPy state tuple: the generator name.
            'numpy_generator': numpy_state[0],
            'capture_success': True,
        },
    }

    logger.info(f"Rank {mpi_rank}: Successfully captured all random states")
    return states

load_states(checkpoint_dir)

Load random states from pickle files.

Parameters:

Name Type Description Default
checkpoint_dir Path

Directory containing saved states

required

Returns:

Type Description
Optional[Dict[str, Any]]

Loaded random states, or None if failed

Source code in june/checkpointing/random_state_manager.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def load_states(self, checkpoint_dir: Path) -> Optional[Dict[str, Any]]:
    """Load random states from pickle files.

    Args:
        checkpoint_dir (Path): Directory containing saved states; a plain
            string path is accepted.

    Returns:
        Optional[Dict[str, Any]]: Loaded random states, or None if the
            file is missing, cannot be read or unpickled, or fails
            validation.

    """
    random_states_dir = Path(checkpoint_dir) / "random_states"
    rank_file = random_states_dir / f"random_states_rank_{mpi_rank}.pkl"

    if not rank_file.exists():
        logger.error(f"Rank {mpi_rank}: Random states file not found: {rank_file}")
        return None

    # SECURITY: pickle.load can execute arbitrary code; the checkpoint
    # directory must come from a trusted source.
    try:
        with open(rank_file, 'rb') as f:
            states = pickle.load(f)
    except (pickle.UnpicklingError, EOFError, AttributeError,
            ImportError, IndexError, OSError) as exc:
        # A truncated or corrupted file is a load failure, which this
        # method reports by returning None rather than by raising.
        logger.error(f"Rank {mpi_rank}: Failed to read random states from {rank_file}: {exc}")
        return None

    logger.info(f"Rank {mpi_rank}: Loaded random states from {rank_file}")

    # Validate loaded data
    if not self._validate_states(states):
        logger.error(f"Rank {mpi_rank}: Loaded random states failed validation")
        return None

    return states

restore_states(states)

Restore random number generator states.

Parameters:

Name Type Description Default
states Dict[str, Any]

Random states to restore

required

Returns:

Name Type Description
bool bool

True if successful

Source code in june/checkpointing/random_state_manager.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def restore_states(self, states: Dict[str, Any]) -> bool:
    """Restore random number generator states.

    Both required entries are checked before either generator is
    modified, so a False return leaves the generators untouched.

    Args:
        states (Dict[str, Any]): Random states to restore

    Returns:
        bool: True if successful, False if a required state is missing

    """
    # Check everything first so a missing entry cannot cause a
    # half-applied restore.
    if 'python_random' not in states:
        logger.error(f"Rank {mpi_rank}: No Python random state found")
        return False
    if 'numpy_random' not in states:
        logger.error(f"Rank {mpi_rank}: No NumPy random state found")
        return False

    # Restore Python random state
    random.setstate(states['python_random'])
    logger.debug(f"Rank {mpi_rank}: Restored Python random state")

    # Restore NumPy random state
    np.random.set_state(states['numpy_random'])
    logger.debug(f"Rank {mpi_rank}: Restored NumPy random state")

    logger.info(f"Rank {mpi_rank}: Successfully restored all random states")
    return True

save_states(states, checkpoint_dir)

Save random states to dedicated pickle files.

Parameters:

Name Type Description Default
states Dict[str, Any]

Random states to save

required
checkpoint_dir Path

Directory to save states in

required

Returns:

Name Type Description
bool bool

True if successful

Source code in june/checkpointing/random_state_manager.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def save_states(self, states: Dict[str, Any], checkpoint_dir: Path) -> bool:
    """Save random states to dedicated pickle files.

    Args:
        states (Dict[str, Any]): Random states to save
        checkpoint_dir (Path): Directory to save states in. Missing
            directories are created; a plain string path is accepted.

    Returns:
        bool: True if successful

    Raises:
        ValueError: If the file read back after writing is not a
            dictionary containing ``'python_random'``.
        OSError: If the directory or file cannot be written or read.

    """
    # Create random states subdirectory. parents=True so that a checkpoint
    # directory that does not exist yet is not an error; exist_ok=True
    # because several ranks may create the same directory.
    random_states_dir = Path(checkpoint_dir) / "random_states"
    random_states_dir.mkdir(parents=True, exist_ok=True)

    # Save each rank's states to separate file
    rank_file = random_states_dir / f"random_states_rank_{mpi_rank}.pkl"

    with open(rank_file, 'wb') as f:
        pickle.dump(states, f, protocol=pickle.HIGHEST_PROTOCOL)

    logger.info(f"Rank {mpi_rank}: Saved random states to {rank_file}")

    # Verify the save by attempting to load
    with open(rank_file, 'rb') as f:
        test_load = pickle.load(f)

    # Basic verification
    if not isinstance(test_load, dict) or 'python_random' not in test_load:
        raise ValueError("Verification failed: saved data is corrupted")

    logger.debug(f"Rank {mpi_rank}: Random state save verification passed")
    return True