ML Collections
ML Collections is a library of Python Collections designed for ML use cases.
ConfigDict
The two classes ConfigDict and FrozenConfigDict are "dict-like" data
structures with dot access to nested elements. Together, they are intended to
be the main way of expressing configurations of experiments and models.
This document describes example usage of ConfigDict, FrozenConfigDict, and
FieldReference.
Features
- Dot-based access to fields.
- Locking mechanism to prevent spelling mistakes.
- Lazy computation.
- FrozenConfigDict() class which is immutable and hashable.
- Type safety.
- "Did you mean" functionality.
- Human readable printing (with valid references and cycles), using valid YAML
format.
- Fields can be passed as keyword arguments using the ** operator.
- There is one exception to the strong type-safety of the ConfigDict: int
values can be passed in to fields of type float. In such a case, the value
is type-converted to a float before being stored. (Back in the day of
Python 2, there was a similar exception to allow both str and unicode
values in string fields.)
Basic Usage
from ml_collections import config_dict
cfg = config_dict.ConfigDict()
cfg.float_field = 12.6
cfg.integer_field = 123
cfg.another_integer_field = 234
cfg.nested = config_dict.ConfigDict()
cfg.nested.string_field = 'tom'
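print(cfg.integer_field)  # Prints 123.
print(cfg['integer_field'])  # Prints 123 as well.
try:
  cfg.integer_field = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)
cfg.float_field = 12  # Works: int values can be assigned to float fields.
cfg.nested.string_field = u'bob'  # String fields can store Unicode strings.
print(cfg)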
FrozenConfigDict
A FrozenConfigDict is an immutable, hashable type of ConfigDict:
from ml_collections import config_dict
initial_dictionary = {
'int': 1,
'list': [1, 2],
'tuple': (1, 2, 3),
'set': {1, 2, 3, 4},
'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}
cfg = config_dict.ConfigDict(initial_dictionary)
frozen_dict = config_dict.FrozenConfigDict(initial_dictionary)
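print(frozen_dict.tuple)  # Prints tuple (1, 2, 3)
print(frozen_dict.list)  # Prints tuple (1, 2)
print(frozen_dict.set)  # Prints frozenset {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0])  # Prints tuple (1, 2)
frozen_cfg = config_dict.FrozenConfigDict(cfg)
print(frozen_cfg == frozen_dict)  # True
print(hash(frozen_cfg) == hash(frozen_dict))  # True
try:
  frozen_dict.int = 2  # Raises AttributeError as FrozenConfigDict is immutable.
except AttributeError as e:
  print(e)
# Converting between FrozenConfigDict and ConfigDict:
thawed_frozen_cfg = config_dict.ConfigDict(frozen_dict)
print(thawed_frozen_cfg == cfg)  # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()
print(frozen_cfg_to_cfg == cfg)  # True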
FieldReferences and placeholders
A FieldReference is useful for having multiple fields use the same value. It
can also be used for lazy computation.
You can use placeholder() as a shortcut to create a FieldReference (field)
with a None default value. This is useful if a program uses optional
configuration fields.
from ml_collections import config_dict
placeholder = config_dict.FieldReference(0)
cfg = config_dict.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = config_dict.ConfigDict()
cfg.nested.placeholder = placeholder
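try:
  cfg.optional = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)
cfg.optional = 1555  # Works fine.
cfg.placeholder = 1  # Changes both placeholder and nested.placeholder.
print(cfg)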
Note that the indirection provided by FieldReferences will be lost if accessed
through a ConfigDict.
from ml_collections import config_dict
placeholder = config_dict.FieldReference(0)
cfg = config_dict.ConfigDict()
cfg.field1 = placeholder
cfg.field2 = placeholder  # This field will be tied to cfg.field1.
cfg.field3 = cfg.field1  # This will just be an int field initialized to 0.
Lazy computation
Using a FieldReference in a standard operation (addition, subtraction,
multiplication, etc.) will return another FieldReference that points to the
original's value. You can use FieldReference.get() to execute the operations
and get the reference's computed value, and FieldReference.set() to change the
original reference's value.
from ml_collections import config_dict
ref = config_dict.FieldReference(1)
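print(ref.get())  # Prints 1
add_ten = ref.get() + 10  # ref.get() is an integer and so is add_ten
add_ten_lazy = ref + 10  # add_ten_lazy is a FieldReference, NOT an integer
print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 11 because ref's value is 1
# Addition is lazily computed for FieldReferences, so changing ref changes
# the value used to compute add_ten_lazy.
ref.set(5)
print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 15 because ref's value is 5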
If a FieldReference has None as its original value, or any operation has an
argument of None, then the lazy computation will evaluate to None.
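The None rule can be illustrated with a small stand-alone model. The LazyRef
class below is a hypothetical, heavily simplified stand-in written for this
illustration. It is not the library's FieldReference implementation and it
supports only addition, but it records the operation, evaluates it on get(),
and short-circuits to None in the way described above.

```python
import operator


class LazyRef:
  """Hypothetical, simplified model of a lazily evaluated reference."""

  def __init__(self, value, op=None, arg=None):
    self._value = value  # A plain value or another LazyRef.
    self._op = op        # Pending binary operation, if any.
    self._arg = arg      # Second operand of the pending operation.

  def set(self, value):
    # Overriding a reference drops any pending operation.
    self._value, self._op, self._arg = value, None, None

  def __add__(self, other):
    # Record the addition instead of performing it immediately.
    return LazyRef(self, op=operator.add, arg=other)

  @staticmethod
  def _resolve(item):
    return item.get() if isinstance(item, LazyRef) else item

  def get(self):
    value = self._resolve(self._value)
    if self._op is None:
      return value
    arg = self._resolve(self._arg)
    # A None operand makes the whole computation evaluate to None.
    if value is None or arg is None:
      return None
    return self._op(value, arg)


ref = LazyRef(None)
plus_ten = ref + 10
before = plus_ten.get()  # None, because ref holds None.
ref.set(5)
after = plus_ten.get()  # 15, computed only now.
none_argument = (LazyRef(1) + LazyRef(None)).get()  # None again.
```

In this model the result becomes a number as soon as every operand is set,
which is what makes None usable as a default for optional fields.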
We can also use fields in a ConfigDict in lazy computation. In this case a
field will only be lazily evaluated if ConfigDict.get_ref() is used to get it.
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference_field = config_dict.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5
config.no_lazy = config.integer_field * config.float_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')
config.lazy_both = (config.get_ref('integer_field') *
                    config.get_ref('float_field'))
config.integer_field = 3
print(config.no_lazy)  # Prints 5.0 - It uses integer_field's original value
print(config.lazy_integer)  # Prints 7.5
config.float_field = 3.5
print(config.lazy_float)  # Prints 7.0
print(config.lazy_both)  # Prints 10.5
Changing lazily computed values
Lazily computed values in a ConfigDict can be overridden in the same way as
regular values. The reference to the FieldReference used for the lazy
computation will be lost and all computations downstream in the reference graph
will use the new value.
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference') + 10
config.reference_1 = config.get_ref('reference') + 20
config.reference_1_0 = config.get_ref('reference_1') + 100
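print(config.reference)  # Prints 1.
print(config.reference_0)  # Prints 11.
print(config.reference_1)  # Prints 21.
print(config.reference_1_0)  # Prints 121.
config.reference_1 = 30
print(config.reference)  # Prints 1 (unchanged).
print(config.reference_0)  # Prints 11 (unchanged).
print(config.reference_1)  # Prints 30.
print(config.reference_1_0)  # Prints 130.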
Cycles
You cannot create cycles using references. Fortunately, the only way to create
a cycle is by assigning a computed field to one that is not the result of
computation. This is forbidden:
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.integer_field = 1
config.bigger_integer_field = config.get_ref('integer_field') + 10
try:
  # Raises a MutabilityError because setting config.integer_field would
  # cause a cycle.
  config.integer_field = config.get_ref('bigger_integer_field') + 2
except config_dict.MutabilityError as e:
  print(e)
One-way references
One gotcha with get_ref is that it creates a bi-directional dependency when no
operations are performed on the value.
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference')
config.reference_0 = 2
print(config.reference)  # Prints 2.
print(config.reference_0)  # Prints 2.
This can be avoided by using get_oneway_ref instead of get_ref.
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference = 1
config.reference_0 = config.get_oneway_ref('reference')
config.reference_0 = 2
print(config.reference)  # Prints 1.
print(config.reference_0)  # Prints 2.
Advanced usage
Here are some more advanced examples showing lazy computation with different
operators and data types.
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.float_field = 12.6
config.integer_field = 123
config.list_field = [0, 1, 2]
config.float_multiply_field = config.get_ref('float_field') * 3
print(config.float_multiply_field)  # Prints 37.8
config.float_field = 10.0
print(config.float_multiply_field)  # Prints 30.0
config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
print(config.longer_list_field)  # Prints [0, 1, 2, 3, 4, 5]
config.list_field = [-1]
print(config.longer_list_field)  # Prints [-1, 3, 4, 5]
# Both operands can be references.
config.ref_subtraction = (
    config.get_ref('float_field') - config.get_ref('integer_field'))
print(config.ref_subtraction)  # Prints -113.0
config.integer_field = 10
print(config.ref_subtraction)  # Prints 0.0
Equality checking
You can use == and .eq_as_configdict() to check equality among ConfigDict and
FrozenConfigDict objects.
from ml_collections import config_dict
dict_1 = {'list': [1, 2]}
dict_2 = {'list': (1, 2)}
cfg_1 = config_dict.ConfigDict(dict_1)
frozen_cfg_1 = config_dict.FrozenConfigDict(dict_1)
frozen_cfg_2 = config_dict.FrozenConfigDict(dict_2)
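# True because FrozenConfigDict converts lists to tuples.
print(frozen_cfg_1.items() == frozen_cfg_2.items())
# False because == distinguishes the underlying difference.
print(frozen_cfg_1 == frozen_cfg_2)
# False because == distinguishes these types.
print(frozen_cfg_1 == cfg_1)
# eq_as_configdict() treats both as ConfigDict, so these are True.
print(frozen_cfg_1.eq_as_configdict(cfg_1))
print(cfg_1.eq_as_configdict(frozen_cfg_1))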
Equality checking with lazy computation
Equality checks compare the computed values. Two configs are equal even if
their computations differ, as long as those computations result in the same
value.
from ml_collections import config_dict
cfg_1 = config_dict.ConfigDict()
cfg_1.a = 1
cfg_1.b = cfg_1.get_ref('a') + 2
cfg_2 = config_dict.ConfigDict()
cfg_2.a = 1
cfg_2.b = cfg_2.get_ref('a') * 3
print(cfg_1 == cfg_2)  # True, because both b fields evaluate to 3.
Locking and copying
Here is an example with lock() and deepcopy():
import copy
from ml_collections import config_dict
cfg = config_dict.ConfigDict()
cfg.integer_field = 123
cfg.lock()
try:
  cfg.intagar_field = 124  # Typo: raises AttributeError as cfg is locked.
except AttributeError as e:
  print(e)
with cfg.unlocked():
  cfg.intagar_field = 1555  # Allowed while the config is unlocked.
new_cfg = copy.deepcopy(cfg)
new_cfg.integer_field = -123
print(cfg)
print(new_cfg)
Output:
'Key "intagar_field" does not exist and cannot be added since the config is locked. Other fields present: "{\'integer_field\': 123}"\nDid you mean "integer_field" instead of "intagar_field"?'
intagar_field: 1555
integer_field: 123
intagar_field: 1555
integer_field: -123
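The "Did you mean" hint in the error message above can be approximated with
the standard library. The sketch below assumes fuzzy string matching via
difflib. The helper name suggest_key is hypothetical and is not part of
ml_collections, whose own matching logic may differ.

```python
import difflib


def suggest_key(missing_key, existing_keys):
  """Returns the closest existing key, or None (hypothetical helper)."""
  matches = difflib.get_close_matches(missing_key, existing_keys, n=1)
  return matches[0] if matches else None


existing = ['integer_field', 'float_field']
suggestion = suggest_key('intagar_field', existing)  # 'integer_field'
no_suggestion = suggest_key('zzz', existing)  # None: nothing is close enough.
print(suggestion)
```

Returning None when nothing is similar keeps the hint from suggesting an
unrelated field.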
Dictionary attributes and initialization
from ml_collections import config_dict
referenced_dict = {'inner_float': 3.14}
d = {
'referenced_dict_1': referenced_dict,
'referenced_dict_2': referenced_dict,
'list_containing_dict': [{'key': 'value'}],
}
cfg = config_dict.ConfigDict(d)
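# Dict attributes are converted to ConfigDict, but shared references are kept.
print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2))  # True
print(type(cfg.referenced_dict_1))  # ConfigDict
# The dict within the list is not converted to ConfigDict.
print(type(cfg.list_containing_dict[0]))  # dict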
More Examples
For more examples, take a look at ml_collections/config_dict/examples/
For examples and gotchas specifically about initializing a ConfigDict, see
ml_collections/config_dict/examples/config_dict_initialization.py.
Config Flags
This library adds flag definitions to absl.flags to handle config files. It
does not wrap absl.flags, so if using any standard flag definitions alongside
config file flags, users must also import absl.flags.
Currently, this module adds two new flag types, namely DEFINE_config_file,
which accepts a path to a Python file that generates a configuration, and
DEFINE_config_dict, which accepts a configuration directly. Configurations are
dict-like structures (see ConfigDict) whose nested elements
can be overridden using special command-line flags. See the examples below
for more details.
Usage
Use ml_collections.config_flags alongside absl.flags. For example:
script.py:
from absl import app
from absl import flags
from ml_collections import config_flags
_CONFIG = config_flags.DEFINE_config_file('my_config')
_MY_FLAG = flags.DEFINE_integer('my_flag', None)
def main(_):
  print(_CONFIG.value)
  print(_MY_FLAG.value)
if __name__ == '__main__':
  app.run(main)
config.py:
from ml_collections import config_dict
def get_config():
  config = config_dict.ConfigDict()
  config.field1 = 1
  config.field2 = 'tom'
  config.nested = config_dict.ConfigDict()
  config.nested.field = 2.23
  config.tuple = (1, 2, 3)
  return config
Warning: If you are using a pickle-based distributed programming framework such
as Launchpad, be aware of limitations on the structure of this script that are
described below in the "Config Files and Pickling" section.
Now, after running:
python script.py --my_config=config.py \
--my_config.field1=8 \
--my_config.nested.field=2.1 \
--my_config.tuple='(1, 2, (1, 2))'
we get:
field1: 8
field2: tom
nested:
  field: 2.1
tuple: !!python/tuple
- 1
- 2
- !!python/tuple
  - 1
  - 2
Usage of DEFINE_config_dict is similar to DEFINE_config_file. The main
difference is that the configuration is defined in script.py instead of in a
separate file.
script.py:
from absl import app
from ml_collections import config_dict
from ml_collections import config_flags
config = config_dict.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = config_dict.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)
_CONFIG = config_flags.DEFINE_config_dict('my_config', config)
def main(_):
  print(_CONFIG.value)
if __name__ == '__main__':
  app.run(main)  # app.run() requires the main function as its argument.
config_file flags are compatible with the command-line flag syntax. All the
following options are supported for non-boolean values in configurations:
- -(-)config.field=value
- -(-)config.field value
Options for boolean values are slightly different:
- -(-)config.boolean_field: set boolean value to True.
- -(-)noconfig.boolean_field: set boolean value to False.
- -(-)config.boolean_field=value: value is true, false, True or False.
Note that -(-)config.boolean_field value is not supported.
Parameterising the get_config() function
It's sometimes useful to be able to pass parameters into get_config, and
change what is returned based on this configuration. One example is if you are
grid searching over parameters which have a different hierarchical structure:
the flag needs to be present in the resulting ConfigDict. It would be possible
to include the union of all possible leaf values in your ConfigDict,
but this produces a confusing config result as you have to remember which
parameters will actually have an effect and which won't.
A better system is to pass some configuration, indicating which structure of
ConfigDict should be returned. An example is the following config file:
from ml_collections import config_dict
def get_config(config_string):
  possible_structures = {
      'linear': config_dict.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': config_dict.ConfigDict({
              'output_size': 42,
          }),
      }),
      'lstm': config_dict.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': config_dict.ConfigDict({
              'hidden_size': 108,
          }),
      }),
  }
  return possible_structures[config_string]
The value of config_string will be anything that is to the right of the first
colon in the config file path, if one exists. If no colon exists, no value is
passed to get_config (producing a TypeError if get_config expects a value).
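The colon rule can be sketched in a few lines. The names split_config_spec and
load are hypothetical and are not part of ml_collections. They only illustrate
how a flag value might be divided at the first colon, and why a missing colon
leads to a TypeError when get_config expects an argument.

```python
def split_config_spec(spec):
  """Splits 'path:config_string' at the first colon (hypothetical helper)."""
  path, sep, config_string = spec.partition(':')
  # No colon means there is no config_string at all, which is different
  # from an empty one.
  return path, (config_string if sep else None)


def get_config(config_string):
  return {'structure': config_string}


def load(spec):
  path, config_string = split_config_spec(spec)
  # A real loader would execute the file at `path`. This sketch calls the
  # local get_config directly instead.
  if config_string is None:
    return get_config()  # Raises TypeError: config_string is required.
  return get_config(config_string)


linear = load('path_to_config.py:linear')
# Only the first colon splits, so later colons stay in config_string.
nested = load('path_to_config.py:a:b')
try:
  load('path_to_config.py')
  missing_raises = False
except TypeError:
  missing_raises = True
```

Because only the first colon is significant, config_string can itself carry a
small structured value that get_config parses further.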
The above example can be run like:
python script.py --config=path_to_config.py:linear \
  --config.model_config.output_size=256
or like:
python script.py --config=path_to_config.py:lstm \
  --config.model_config.hidden_size=512
Additional features
- Loads any valid python script which defines get_config() function
returning any python object.
- Automatic locking of the loaded object, if the loaded object defines a
callable .lock() method.
- Supports command-line overriding of arbitrarily nested values in dict-like
objects (with key/attribute based getters/setters) of the following types:
  - int
  - float
  - bool
  - str
  - tuple (but not list)
  - enum.Enum
- Overriding is type safe.
- Overriding of a tuple can be done by passing in the tuple value as a
string (see the example in the Usage section).
- The overriding tuple object can be of a different length and have
different item types than the original. Nested tuples are also supported.
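As a rough illustration of the tuple rule, a string override can be turned
into a tuple with literal evaluation. The function parse_tuple_override is a
hypothetical name used only for this sketch, and the library's own parser may
behave differently in edge cases.

```python
import ast


def parse_tuple_override(text):
  """Parses a tuple literal such as '(1, 2, (1, 2))' (hypothetical helper)."""
  value = ast.literal_eval(text)  # Evaluates literals only, never code.
  if not isinstance(value, tuple):
    raise ValueError(f'Expected a tuple literal, got {type(value).__name__}')
  return value


original = (1, 2, 3)
# The override may differ from the original in length and item types.
override = parse_tuple_override('(1, 2, (1, 2))')

try:
  parse_tuple_override('[1, 2]')
  list_rejected = False
except ValueError:
  list_rejected = True
```

Rejecting non-tuple literals mirrors the "tuple (but not list)" restriction in
the list above.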
Config Files and Pickling
This is likely to be troublesome:
import dataclasses
@dataclasses.dataclass
class MyRecord:
  num_balloons: int
  color: str
def get_config():
  return MyRecord(num_balloons=99, color='red')
This is not:
import dataclasses
def get_config():
  @dataclasses.dataclass
  class MyRecord:
    num_balloons: int
    color: str
  return MyRecord(num_balloons=99, color='red')
Explanation
A config file is a Python module but it is not imported through Python's usual
module-importing mechanism.
Meanwhile, serialization libraries such as cloudpickle (which is used by
Launchpad) and Apache Beam expect to be able to pickle an object without also
pickling every type to which it refers, on the assumption that types defined
at module scope can later be reconstructed simply by re-importing the modules
in which they are defined.
That assumption does not hold for a type that is defined at module scope in a
config file, because the config file can't be imported the usual way. The
symptom of this will be an ImportError when unpickling an object.
The treatment is to move types from module scope into get_config() so that
they will be serialized along with the values that have those types.
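The failure mode can be reproduced with the standard library alone. This is a
sketch under stated assumptions: it uses the standard pickle module rather than
cloudpickle, the module name hypothetical_config_file is made up, and deleting
the entry from sys.modules stands in for unpickling in a different process.

```python
import pickle
import sys
import types

CONFIG_SOURCE = '''
import dataclasses

@dataclasses.dataclass
class MyRecord:
  num_balloons: int
  color: str

def get_config():
  return MyRecord(num_balloons=99, color='red')
'''

MODULE_NAME = 'hypothetical_config_file'  # Made-up name, not a real module.

# Simulate a loader that executes a config file that is not on sys.path.
module = types.ModuleType(MODULE_NAME)
sys.modules[MODULE_NAME] = module
exec(CONFIG_SOURCE, module.__dict__)

# Pickling succeeds: the class is stored by reference (module name plus
# class name), not by value.
payload = pickle.dumps(module.get_config())

# Simulate another process in which the config file was never loaded.
del sys.modules[MODULE_NAME]
try:
  pickle.loads(payload)
  unpickle_error = None
except ImportError as e:
  unpickle_error = e

print(type(unpickle_error).__name__)
```

The payload contains only the name of the module, so unpickling has to import
it, and that import fails because the config file is not importable the usual
way.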
Authors