Introduction
One of the nicest features of the GPflow software is the ability to create an instance of your model and then set certain parameters to be fixed during all or part of training. Typically this can be done in a single line of code, something like
gpflow.utilities.set_trainable(model.param, False)
In contrast, the more functional nature of models in JAX/Flax means that we shift away from working with a single instantiated model object: most of the work is done by composable functional layers in the nn.Module class, and these are lightly wrapped inside an nn.Model class. The model class behaves like a frozen dataclass, so any time we want to modify a parameter we actually have to create a new instance of the model, with the unmodified parameters copied over.
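As a minimal sketch of what that looks like (assuming a model whose params dictionary holds a single entry 'a'), modifying a parameter means building a new Model via replace rather than mutating the existing one:

# hypothetical illustration: produce a new Model with one parameter doubled,
# leaving the original instance untouched
new_model = model.replace(
    params={**model.params, 'a': 2.0 * model.params['a']})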
In the following we briefly sketch a functional approach to recreating the general idea of fixing some parameters and not modifying them during training using models constructed in Flax.
Fixing the parameters of a flax.nn.Module
First let's create a simple multiplication layer with a single parameter a:
import flax
import jax
import jax.numpy as jnp
from jax import random


class Layer(flax.nn.Module):
    def apply(self, x, a_fixed=False, a_init=jax.nn.initializers.ones):
        a = self.param('a', (1,), a_init)
        if a_fixed:
            # ignore the stored parameter and rebuild `a` from its initializer
            _default_key = random.PRNGKey(0)
            a = a_init(_default_key, (1,))
        return a * x
Note that this code always initializes the parameter a, but if we also pass a_fixed=True then the stored parameter never reaches the return. This has the advantage that a is still present in model.params, so we can for example modify the fixed parameter using the replace method, but it also means that in any gradient-based training the parameter never reaches the tape, so its gradient is identically zero. There is a cost to this disconnect which we will return to later.
Now we will imagine this is called inside of some larger model
class MyModel(flax.nn.Module):
    def apply(self, x, **kwargs):
        x = Layer(x, **kwargs.get('layer_kwargs', {}), name='layer')
        return x
Of course, keeping track of all of these kwargs every single time we use model.apply(...) is going to be a real headache!
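As a rough illustration (assuming flax.nn.Model forwards extra keyword arguments through to apply), every forward pass would need the nested kwargs spelled out again:

# without partial, the layer kwargs must be threaded through every call
y = model(x, layer_kwargs={'a_fixed': True})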
One method of handling this is to use the .partial method of flax.nn.Module, as demonstrated over at the flax docs. Using this we have a one-time burden of specifying a larger set of kwargs, then use the partial method to create a new definition of our model, and optionally write a model-creation helper that is aware of these different definitions. The result is something like
free_kwargs = {'layer_kwargs': {'a_fixed': False}}
fixed_kwargs = {'layer_kwargs': {'a_fixed': True}}

# use partial to fix these kwargs in the model definitions
free_model_def = MyModel.partial(**free_kwargs)
fixed_model_def = MyModel.partial(**fixed_kwargs)


def create_model(model_def, key, input_specs):
    x, init_params = model_def.init_by_shape(key, input_specs)
    return flax.nn.Model(model_def, init_params)
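The snippets below work with free_model and fixed_model, which are presumably built with this helper along the following lines; the input shape used here (and reused by the loss function below) is an assumption:

# assumed input spec: a single (1, 1) float32 input
input_shape_and_dtype = [((1, 1), jnp.float32)]
x = jnp.ones(*input_shape_and_dtype[0])

key = random.PRNGKey(0)
free_model = create_model(free_model_def, key, input_shape_and_dtype)
fixed_model = create_model(fixed_model_def, key, input_shape_and_dtype)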
As we can see, the fixed model still has a as a parameter:
>>> fixed_model.params
{'layer': {'a': DeviceArray([1.], dtype=float32)}}
Now if we define a test function and take gradients we find
def loss_fn(model):
    x = jnp.ones(*input_shape_and_dtype[0])
    y = model(x)
    return jnp.mean(y ** 2)


free_model_grad = jax.grad(loss_fn)(free_model)
fixed_model_grad = jax.grad(loss_fn)(fixed_model)

assert loss_fn(free_model) == loss_fn(fixed_model)
assert free_model_grad.params['layer']['a'] == 2
assert fixed_model_grad.params['layer']['a'] == 0
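One consequence is that a gradient step leaves the fixed parameter exactly where it started. A minimal sketch using flax.optim (the optimizer API that accompanies flax.nn; treat the exact calls here as an assumption):

# one plain gradient-descent step; the fixed parameter receives a zero
# gradient and therefore keeps its initial value
optimizer = flax.optim.GradientDescent(learning_rate=0.1).create(fixed_model)
grads = jax.grad(loss_fn)(optimizer.target)
optimizer = optimizer.apply_gradient(grads)
assert optimizer.target.params['layer']['a'] == 1.0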
Big caveat
Some of you may have noticed the issue which will occur in the following use case
new_fixed_model = fixed_model.replace(
    params={'layer': {'a': 3.14 * jnp.ones([1])}})
assert new_fixed_model(x) != fixed_model(x)
will raise an AssertionError! This is because of the disconnect noted earlier between the stored parameter and the layer output: replacing a in the params has no effect on what the fixed layer computes.
Instead what we want is actually a new model definition
new_fixed_kwargs = {
    'layer_kwargs': {
        'a_fixed': True,
        'a_init': lambda key, shape: 3.14 * jnp.ones([1])}}
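A minimal sketch of how these new kwargs would then be used, reusing the create_model helper from above:

# a fresh model definition whose fixed parameter is pinned at 3.14
new_fixed_model_def = MyModel.partial(**new_fixed_kwargs)
new_fixed_model = create_model(new_fixed_model_def, key, input_shape_and_dtype)
assert new_fixed_model(x) != fixed_model(x)  # the output now reflects a = 3.14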
We will return to this point in part II of this series.
Further comments and ToDos
In contrast to GPflow, we are not able to take an existing model instance and change a previously free parameter to a fixed one – ultimately it does not seem to be intended behaviour in Flax for an instance to persist over long stretches of the implementation. Instead we can quote the docs:

A model instance is callable and functional (e.g. changing parameters requires a new model instance.)

This takes some getting used to if you are coming from a more object-oriented perspective, and it does make the process of model creation feel quite different. While it would also be possible to wrap this inside more syntactic sugar, in the way that the current GPflow parameters themselves wrap the base Module class with its variable tracking in TensorFlow, it is probably best not to do this. Instead it seems preferable to have the user write a potentially large number of lines to make their model behave the way they want, so long as each of these lines is simple and their composition intuitive and easy to follow.
In the next post we will show how to better adapt this method to handle changing the fixed parameter. Over the following weeks we will then demonstrate some examples of various models coded in Flax compared to their GPflow equivalents, to give some sense of how these different approaches manifest for the user/model builder.