import time
from collections import OrderedDict
from functools import total_ordering
from plums.commons.path import Path
from plums.model.exception import PlumsModelMetadataValidationError
from plums.model.validation.metadata import Training as TrainingMetadata
from plums.model.components.version import version
from plums.model.components.utils import TrainingStatus, Checkpoint, is_duplicate
class Training(object):
    """Define a Python representation of a model training state.

    When ``latest_epoch`` (resp. ``latest_time``) is not provided, it falls back on ``end_epoch``
    (resp. ``end_time``) if given, and on ``start_epoch`` (resp. ``start_time``) otherwise.

    Args:
        start_time (float): The training starting time, or ``None`` if non existent
        start_epoch (int): The training starting epoch, or ``None`` if non existent
        latest_time (float): The training latest epoch time, or ``None`` if non existent
        latest_epoch (int): The training latest epoch, or ``None`` if non existent
        end_time (float): The training ending time, or ``None`` if non existent
        end_epoch (int): The training ending epoch, or ``None`` if non existent
        status (str): The training status, or ``'pending'`` if non existent

    Raises:
        ValueError: If the provided parameters do not pass the training metadata validation.

    """

    def __init__(self, start_time=None, start_epoch=None, latest_time=None, latest_epoch=None,
                 end_time=None, end_epoch=None, status='pending'):
        # Infer missing "latest" values from the end values first, then from the start values.
        if latest_epoch is None:
            latest_epoch = end_epoch if end_epoch is not None else start_epoch
        if latest_time is None:
            latest_time = end_time if end_time is not None else start_time

        # Validate the whole parameter set against the training metadata schema before storing anything.
        try:
            TrainingMetadata(verbose=True).validate({
                'status': status,
                'start_time': start_time,
                'start_epoch': start_epoch,
                'latest_time': latest_time,
                'latest_epoch': latest_epoch,
                'end_time': end_time,
                'end_epoch': end_epoch,
                'latest': None,
                'checkpoints': {}
            })
        except PlumsModelMetadataValidationError as e:
            raise ValueError('Invalid training parameters: \n{}'.format(e.code))

        self._start_epoch = start_epoch
        self._start_timestamp = start_time
        self._latest_epoch = latest_epoch
        self._latest_timestamp = latest_time
        self._end_epoch = end_epoch
        self._end_timestamp = end_time
        self._status = TrainingStatus(status)

    @property
    def status(self):
        """str: The training status, *i.e.* pending, running, failed or finished."""
        return str(self._status)

    @status.setter
    def status(self, value):
        self._status.status = value

    @property
    def start_epoch(self):
        """int: The training starting epoch if any, or ``None`` otherwise."""
        return self._start_epoch

    @property
    def start_timestamp(self):
        """float: The training starting timestamp if any, or ``None`` otherwise."""
        return self._start_timestamp

    @property
    def latest_epoch(self):
        """int: The training latest epoch if any, or ``None`` otherwise."""
        return self._latest_epoch

    @property
    def latest_timestamp(self):
        """float: The training latest timestamp if any, or ``None`` otherwise."""
        return self._latest_timestamp

    @property
    def end_epoch(self):
        """int: The training ending epoch if any, or ``None`` otherwise."""
        return self._end_epoch

    @property
    def end_timestamp(self):
        """float: The training ending timestamp if any, or ``None`` otherwise."""
        return self._end_timestamp

    @property
    def is_running(self):
        """bool: Whether the training is running."""
        return self.status == 'running'

    @property
    def is_pending(self):
        """bool: Whether the training is pending."""
        return self.status == 'pending'

    @property
    def is_finished(self):
        """bool: Whether the training is finished."""
        return self.status == 'finished'

    @property
    def is_failed(self):
        """bool: Whether the training is failed."""
        return self.status == 'failed'

    def __repr__(self):
        """Represent a training."""
        return '{}({})'.format(self.__class__.__name__, self.status)

    def start(self, epoch):
        """Start the training and register the starting epoch and timestamp.

        The starting epoch and timestamp are also registered as the latest ones.

        Args:
            epoch (int): The starting epoch.

        """
        self._status.status = 'running'
        self._start_epoch = epoch
        self._start_timestamp = time.time()
        self._latest_epoch = epoch
        self._latest_timestamp = self._start_timestamp

    def interrupt(self):
        """Interrupt the training in a non-standard way and register the latest epoch and timestamp as the end."""
        self._status.status = 'failed'
        self._end_epoch = self._latest_epoch
        self._end_timestamp = self._latest_timestamp

    def end(self):
        """End the training in a standard way and register the latest epoch and timestamp as the end."""
        self._status.status = 'finished'
        self._end_epoch = self._latest_epoch
        self._end_timestamp = self._latest_timestamp

    def register_epoch(self, epoch=None):
        """Register a given epoch as being the latest epoch in the training along with its timestamp.

        The latest epoch never decreases: an ``epoch`` lower than :attr:`latest_epoch` leaves it untouched,
        although the latest timestamp is still refreshed.

        Args:
            epoch (int): Optional. Default to :attr:`latest_epoch` + 1. The epoch to be registered as latest.

        Raises:
            TypeError: If ``epoch`` is omitted while no epoch was ever registered, since there is nothing
                to increment.

        """
        if epoch is None:
            if self._latest_epoch is None:
                # Fail explicitly (and before mutating anything) instead of on an opaque ``None += 1``.
                raise TypeError('Can not increment the latest epoch: no epoch has been registered yet.')
            self._latest_epoch += 1
        elif self._latest_epoch is None:
            # ``max(None, epoch)`` is an error on Python 3: with no previous epoch, the given one is the latest.
            self._latest_epoch = epoch
        else:
            self._latest_epoch = max(self._latest_epoch, epoch)
        self._latest_timestamp = time.time()
class CheckpointCollection(object):
    """Define a checkpoint collection Python representation.

    The collection maps unique references (the |Checkpoint| names) to |Checkpoint|, in insertion order.
    A |Checkpoint| may be added through ``__setitem__`` or through :meth:`add`.

    Args:
        *checkpoints (|Checkpoint|): The |Checkpoint| to be added to the collection, in order.

    """

    def __init__(self, *checkpoints):
        self._checkpoints = OrderedDict()
        self._references = set()
        self._latest = None

        # add() keeps ``_latest`` pointing at the last inserted checkpoint.
        for checkpoint in checkpoints:
            self.add(checkpoint)

    @property
    def latest(self):
        """hashable: A reference to the latest |Checkpoint| added to the collection."""
        return self._latest

    @property
    def eloc(self):
        """|CheckpointCollection|: Retrieve a |CheckpointCollection| from an epoch number."""
        class _EpochIndexer(object):
            def __init__(self, checkpoint_collection):
                self._checkpoint_collection = checkpoint_collection

            def __repr__(self):
                """Represent a checkpoint collection."""
                return self._checkpoint_collection.__repr__()

            def get_references_for_epoch(self, epoch):
                # References are returned in the collection insertion order.
                return [ref for ref, value in self._checkpoint_collection.items() if value.epoch == epoch]

            def __getitem__(self, key):
                references = self.get_references_for_epoch(key)
                if not references:
                    raise IndexError('Invalid epoch number provided: {}'.format(key))
                return CheckpointCollection(*(self._checkpoint_collection[reference] for reference in references))

        return _EpochIndexer(self)

    @property
    def iloc(self):
        """|Checkpoint|, (|Checkpoint|, ): Retrieve a |Checkpoint| from its insertion index."""
        class _IndexIndexer(object):
            def __init__(self, checkpoint_collection):
                self._checkpoint_collection = checkpoint_collection
                # Snapshot of the references, taken when the indexer is created.
                self._index = tuple(self._checkpoint_collection.keys())

            def __repr__(self):
                """Represent a checkpoint collection."""
                return self._checkpoint_collection.__repr__()

            def __getitem__(self, key):
                # A slice yields a tuple of checkpoints, a plain index yields a single checkpoint.
                if isinstance(key, slice):
                    return tuple(self._checkpoint_collection[reference] for reference in self._index[key])
                return self._checkpoint_collection[self._index[key]]

        return _IndexIndexer(self)

    def __repr__(self):
        """Represent a checkpoint collection."""
        return '{}({})'.format(self.__class__.__name__, tuple(self.values()))

    def __eq__(self, other):
        """Return whether two |CheckpointCollection| have the same set of |Checkpoint|.

        Insertion order is not taken into account.

        Args:
            other (|CheckpointCollection|): A |CheckpointCollection| to compare to.

        Returns:
            bool: ``True`` if both have the same set of |Checkpoint|.

        """
        try:
            return set(self.items()) == set(other.items())
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        """Return whether two |CheckpointCollection| do not have the same set of |Checkpoint|.

        Args:
            other (|CheckpointCollection|): A |CheckpointCollection| to compare to.

        Returns:
            bool: ``True`` if both do not have the same set of |Checkpoint|.

        """
        return not self == other

    def add(self, checkpoint):
        """Add a |Checkpoint| to the collection and mark it as the latest.

        Args:
            checkpoint (|Checkpoint|): A |Checkpoint| to add to the collection.

        Raises:
            KeyError: If the |Checkpoint| name was already registered as a reference.
            ValueError: If the |Checkpoint| epoch is ``None``.

        """
        if checkpoint.name in self._references:
            raise KeyError('Checkpoint name {} already registered: {}'.format(checkpoint.name,
                                                                              self._checkpoints[checkpoint.name]))

        if checkpoint.epoch is None:
            raise ValueError('Checkpoint epoch is None in a CheckpointCollection.')

        self._references.add(checkpoint.name)
        self._checkpoints[checkpoint.name] = checkpoint
        self._latest = checkpoint.name

    def __getitem__(self, reference):
        """Retrieve a |Checkpoint| by its unique reference (its name).

        Args:
            reference (hashable): A |Checkpoint| unique reference.

        Returns:
            |Checkpoint|: The corresponding |Checkpoint|

        Raises:
            KeyError: If the given reference is not found in set of known references.

        """
        return self._checkpoints[reference]

    def __setitem__(self, reference, checkpoint):
        """Set a |Checkpoint| name to a new unique reference and either add it to the collection or replace the old one.

        Replacing an existing reference keeps its position in the insertion order.

        Args:
            reference (hashable): A new |Checkpoint| unique reference.
            checkpoint (|Checkpoint|): A |Checkpoint| to add to the collection.

        Raises:
            ValueError: If the |Checkpoint| epoch is ``None``.

        """
        # Check the epoch up front: add() only checks it for new references, so a replacement would
        # otherwise silently store a checkpoint without epoch.
        if checkpoint.epoch is None:
            raise ValueError('Checkpoint epoch is None in a CheckpointCollection.')

        checkpoint = Checkpoint(name=reference, path=checkpoint.path, epoch=checkpoint.epoch, hash=checkpoint.hash)
        try:
            self.add(checkpoint)
        except KeyError:
            # The reference is already registered: replace the stored checkpoint in place.
            self._checkpoints[reference] = checkpoint

    def __delitem__(self, reference):
        """Delete a |Checkpoint| by its unique reference (its name).

        The latest reference becomes the last remaining one in insertion order, or ``None`` if the
        collection is left empty.

        Args:
            reference (hashable): A |Checkpoint| unique reference.

        Raises:
            KeyError: If the given reference is not found in set of known references.

        """
        del self._checkpoints[reference]
        self._references.discard(reference)
        # Peek at the last inserted reference without copying every key.
        self._latest = next(reversed(self._checkpoints), None)

    def __contains__(self, reference):
        """Return whether a given reference exists in the |CheckpointCollection|.

        Args:
            reference (hashable): A reference to assert existence on.

        Returns:
            bool: ``True`` if reference is registered in the collection.

        """
        return reference in self._references

    def __len__(self):
        """Return the number of registered references in the |CheckpointCollection|.

        Returns:
            int: The number of registered references in the collection.

        """
        return len(self._references)

    def get(self, reference, default=None):
        """Retrieve a |Checkpoint| by its unique reference (its name) or a ``default`` if reference is not found.

        Args:
            reference (hashable): A |Checkpoint| unique reference.
            default (Any): The default to return if reference is not a valid |Checkpoint| reference.

        Returns:
            |Checkpoint|, Any: The corresponding |Checkpoint| or ``default``.

        """
        return self._checkpoints.get(reference, default)

    def keys(self):
        """Return the collection's references as in a dict, in insertion order.

        Returns:
            Iterable[hashable]: The references of the registered |Checkpoint|.

        """
        return self._checkpoints.keys()

    def values(self):
        """Return the collection's |Checkpoint| as in a dict, in insertion order.

        Returns:
            Iterable[:class:`~plums.model.components.utils.Checkpoint`]: The registered |Checkpoint|.

        """
        return self._checkpoints.values()

    def items(self):
        """Return the collection's content as in a dict, in insertion order.

        Returns:
            Iterable[(hashable, :class:`~plums.model.components.utils.Checkpoint`)]: Reference and |Checkpoint| pairs.

        """
        return self._checkpoints.items()
@total_ordering
class Producer(object):
    """Define a Python representation of a **PMF** model producer with its configuration.

    Two |Producer| are only ordered relative to each other when they share the same :attr:`name`: every
    ordering comparison between |Producer| of different names is ``False``.

    Args:
        name (str): The name of the package that produced the model.
        version_format (str): The version *format* of the package that produced the model.
        version_string (str): The version *representation string* of the package that produced the model.
        configuration (Pathlike): A path to the producer configuration file used to produce the model.

    Raises:
        OSError: If ``configuration`` does not point to an existing file.

    See Also:
        The |ProducerVersion| class for more information on the |Producer| version handling.

    Attributes:
        name (str): The name of the package that produced the model.
        version (|ProducerVersion|): The version of the package that produced the model.
        configuration (Pathlike): A path to the producer configuration file used to produce the model.

    """

    # NOTE: All four ordering methods are defined explicitly, which leaves nothing for ``total_ordering`` to
    # derive. Its derived ``__gt__``/``__ge__`` assume a total order and would return ``True`` in both
    # directions for producers of different names.

    def __init__(self, name, version_format, version_string, configuration):
        configuration_path = Path(configuration)
        if not configuration_path.is_file():
            raise OSError('Invalid configuration: {} is not a file.'.format(configuration))

        self.name = name
        self.version = version(version_format, version_string)
        self.configuration = configuration_path

    def __repr__(self):
        """Represent a producer."""
        return '{}(name={}, version={})'.format(self.__class__.__name__, self.name, repr(self.version))

    def __eq__(self, other):
        """Return if two |Producer| have the same :attr:`name` and :attr:`version`.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are identical in name and |ProducerVersion|.

        """
        try:
            return self.name == other.name and self.version == other.version
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        """Return if two |Producer| do not have the same :attr:`name` and :attr:`version`.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are not identical in name and |ProducerVersion|.

        """
        return not self == other

    def __lt__(self, other):
        """If two |Producer| have the same :attr:`name`, return their :attr:`version` order.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are identical in name, and if self :attr:`version` is less than
                ``other`` |ProducerVersion|.

        """
        try:
            return self.name == other.name and self.version < other.version
        except AttributeError:
            return NotImplemented

    def __le__(self, other):
        """If two |Producer| have the same :attr:`name`, return whether self :attr:`version` is at most ``other``'s.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are identical in name, and if self :attr:`version` is less than
                or equal to ``other`` |ProducerVersion|.

        """
        try:
            return self.name == other.name and (self.version < other.version or self.version == other.version)
        except AttributeError:
            return NotImplemented

    def __gt__(self, other):
        """If two |Producer| have the same :attr:`name`, return whether self :attr:`version` is above ``other``'s.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are identical in name, and if self :attr:`version` is neither
                less than nor equal to ``other`` |ProducerVersion|.

        """
        try:
            return (self.name == other.name
                    and not self.version < other.version
                    and not self.version == other.version)
        except AttributeError:
            return NotImplemented

    def __ge__(self, other):
        """If two |Producer| have the same :attr:`name`, return whether self :attr:`version` is at least ``other``'s.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are identical in name, and if self :attr:`version` is not less
                than ``other`` |ProducerVersion|.

        """
        try:
            return self.name == other.name and not self.version < other.version
        except AttributeError:
            return NotImplemented

    def strict_equals(self, other):
        """Return if two |Producer| have the same :attr:`name` and :attr:`version` and :attr:`configuration` file.

        Args:
            other (|Producer|): Another |Producer| to compare with.

        Returns:
            bool: ``True`` if the two |Producer| are identical in name, |ProducerVersion| and configuration.
                ``False`` otherwise, including when ``other`` lacks a ``configuration`` attribute.

        """
        try:
            if not self == other:
                return False
            own_configuration = str(self.configuration)
            other_configuration = str(other.configuration)
            return is_duplicate(own_configuration, other_configuration)
        except AttributeError:
            # ``NotImplemented`` is truthy and only meaningful for rich comparison methods:
            # a plain predicate must answer ``False`` for objects it can not compare with.
            return False