diff --git a/CHANGELOG.md b/CHANGELOG.md index 1408ee4f1..b7cf52042 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,7 @@ To release a new version, please update the changelog as followed: ## [Unreleased] ### Added +- Support nested layer customization (#PR 1015) ### Changed @@ -83,7 +84,7 @@ To release a new version, please update the changelog as followed: ### Fixed - Fix `tf.models.Model._construct_graph` for list of outputs, e.g. STN case (PR #1010) - +- Enable better `in_channels` exception raise. (pR #1015) ### Removed ### Security @@ -91,7 +92,7 @@ To release a new version, please update the changelog as followed: ### Contributors - @zsdonghao -- @ChrisWu1997: #1010 +- @ChrisWu1997: #1010 #1015 ## [2.1.0] diff --git a/examples/database/task_script.py b/examples/database/task_script.py index 2076847dd..3d77102b1 100644 --- a/examples/database/task_script.py +++ b/examples/database/task_script.py @@ -12,6 +12,7 @@ # load dataset from database X_train, y_train, X_val, y_val, X_test, y_test = db.find_top_dataset('mnist') + # define the network def mlp(): ni = tl.layers.Input([None, 784], name='input') @@ -24,15 +25,18 @@ def mlp(): M = tl.models.Model(inputs=ni, outputs=net) return M + network = mlp() # cost and accuracy cost = tl.cost.cross_entropy + def acc(y, y_): correct_prediction = tf.equal(tf.argmax(y, 1), tf.convert_to_tensor(y_, tf.int64)) return tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + # define the optimizer train_op = tf.optimizers.Adam(learning_rate=0.0001) @@ -43,8 +47,17 @@ def acc(y, y_): # ) tl.utils.fit( - network, train_op=tf.optimizers.Adam(learning_rate=0.0001), cost=tl.cost.cross_entropy, X_train=X_train, - y_train=y_train, acc=acc, batch_size=256, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=False, + network, + train_op=tf.optimizers.Adam(learning_rate=0.0001), + cost=tl.cost.cross_entropy, + X_train=X_train, + y_train=y_train, + acc=acc, + batch_size=256, + n_epoch=20, + X_val=X_val, + y_val=y_val, + eval_train=False, ) # evaluation and save result that match the result_key @@ -55,5 +68,3 @@ def acc(y, y_): db.save_model(network, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy) # in other script, you can load the model as follow # net = db.find_model(sess=sess, model_name=str(n_units1)+'-'+str(n_units2) - -tf.python.keras.layers.BatchNormalization \ No newline at end of file diff --git a/tensorlayer/layers/core.py b/tensorlayer/layers/core.py index 7049e216c..6fcbd4aa8 100644 --- a/tensorlayer/layers/core.py +++ b/tensorlayer/layers/core.py @@ -127,8 +127,11 @@ def __init__(self, name=None, act=None, *args, **kwargs): # Layer weight state self._all_weights = None - self._trainable_weights = None - self._nontrainable_weights = None + self._trainable_weights = [] + self._nontrainable_weights = [] + + # nested layers + self._layers = None # Layer training state self.is_train = True @@ -179,20 +182,18 @@ def all_weights(self): if self._all_weights is not None and len(self._all_weights) > 0: pass else: - self._all_weights = list() - if self._trainable_weights is not None: - self._all_weights.extend(self._trainable_weights) - if self._nontrainable_weights is not None: - self._all_weights.extend(self._nontrainable_weights) + self._all_weights = self.trainable_weights + self.nontrainable_weights return self._all_weights @property def trainable_weights(self): - return self._trainable_weights + nested = self._collect_sublayers_attr('trainable_weights') + return self._trainable_weights + nested @property def nontrainable_weights(self): - return self._nontrainable_weights + nested = self._collect_sublayers_attr('nontrainable_weights') + return self._nontrainable_weights + nested @property def weights(self): @@ -200,6 +201,21 @@ def weights(self): "no property .weights exists, do you mean .all_weights, .trainable_weights, or .nontrainable_weights ?" ) + def _collect_sublayers_attr(self, attr): + if attr not in ['trainable_weights', 'nontrainable_weights']: + raise ValueError( + "Only support to collect some certain attributes of nested layers," + "e.g. 'trainable_weights', 'nontrainable_weights', but got {}".format(attr) + ) + if self._layers is None: + return [] + nested = [] + for layer in self._layers: + value = getattr(layer, attr) + if value is not None: + nested.extend(value) + return nested + def __call__(self, inputs, *args, **kwargs): """ (1) Build the Layer if necessary. @@ -326,6 +342,20 @@ def __setitem__(self, key, item): def __delitem__(self, key): raise TypeError("The Layer API does not allow to use the method: `__delitem__`") + def __setattr__(self, key, value): + if isinstance(value, Layer): + value._nodes_fixed = True + if self._layers is None: + self._layers = [] + self._layers.append(value) + super().__setattr__(key, value) + + def __delattr__(self, name): + value = getattr(self, name, None) + if isinstance(value, Layer): + self._layers.remove(value) + super().__delattr__(name) + @protected_method def get_args(self): init_args = {"layer_type": "normal"} diff --git a/tensorlayer/models/core.py b/tensorlayer/models/core.py index 52d8dc083..2d62d2b56 100644 --- a/tensorlayer/models/core.py +++ b/tensorlayer/models/core.py @@ -591,6 +591,15 @@ def _fix_nodes_for_layers(self): layer._fix_nodes_for_layers() self._nodes_fixed = True + def __setattr__(self, key, value): + if isinstance(value, Layer): + if value._built is False: + raise AttributeError( + "The registered layer `{}` should be built in advance. " + "Do you forget to pass the keyword argument 'in_channels'? ".format(value.name) + ) + super().__setattr__(key, value) + def __repr__(self): # tmpstr = self.__class__.__name__ + '(\n' tmpstr = self.name + '(\n' diff --git a/tests/layers/test_layers_core_nested.py b/tests/layers/test_layers_core_nested.py new file mode 100644 index 000000000..e44c12f3a --- /dev/null +++ b/tests/layers/test_layers_core_nested.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-\ +import os +import unittest + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import tensorflow as tf +import tensorlayer as tl +import numpy as np + +from tests.utils import CustomTestCase + + +class Layer_nested(CustomTestCase): + + @classmethod + def setUpClass(cls): + print("##### begin testing nested layer #####") + + @classmethod + def tearDownClass(cls): + pass + # tf.reset_default_graph() + + def test_nested_layer_with_inchannels(cls): + + class MyLayer(tl.layers.Layer): + + def __init__(self, name=None): + super(MyLayer, self).__init__(name=name) + self.input_layer = tl.layers.Dense(in_channels=50, n_units=20) + self.build(None) + self._built = True + + def build(self, inputs_shape=None): + self.W = self._get_weights('weights', shape=(20, 10)) + + def forward(self, inputs): + inputs = self.input_layer(inputs) + output = tf.matmul(inputs, self.W) + return output + + class model(tl.models.Model): + + def __init__(self, name=None): + super(model, self).__init__(name=name) + self.layer = MyLayer() + + def forward(self, inputs): + return self.layer(inputs) + + input = tf.random.normal(shape=(100, 50)) + model_dynamic = model() + model_dynamic.train() + cls.assertEqual(model_dynamic(input).shape, (100, 10)) + cls.assertEqual(len(model_dynamic.all_weights), 3) + cls.assertEqual(len(model_dynamic.trainable_weights), 3) + model_dynamic.layer.input_layer.b.assign_add(tf.ones((20, ))) + cls.assertEqual(np.sum(model_dynamic.all_weights[-1].numpy() - tf.ones(20, ).numpy()), 0) + + ni = tl.layers.Input(shape=(100, 50)) + nn = MyLayer(name='mylayer1')(ni) + model_static = tl.models.Model(inputs=ni, outputs=nn) + model_static.eval() + cls.assertEqual(model_static(input).shape, (100, 10)) + cls.assertEqual(len(model_static.all_weights), 3) + cls.assertEqual(len(model_static.trainable_weights), 3) + model_static.get_layer('mylayer1').input_layer.b.assign_add(tf.ones((20, ))) + cls.assertEqual(np.sum(model_static.all_weights[-1].numpy() - tf.ones(20, ).numpy()), 0) + + def test_nested_layer_without_inchannels(cls): + + class MyLayer(tl.layers.Layer): + + def __init__(self, name=None): + super(MyLayer, self).__init__(name=name) + self.input_layer = tl.layers.Dense(n_units=20) # no need for in_channels here + self.build(None) + self._built = True + + def build(self, inputs_shape=None): + self.W = self._get_weights('weights', shape=(20, 10)) + + def forward(self, inputs): + inputs = self.input_layer(inputs) + output = tf.matmul(inputs, self.W) + return output + + class model(tl.models.Model): + + def __init__(self, name=None): + super(model, self).__init__(name=name) + self.layer = MyLayer() + + def forward(self, inputs): + return self.layer(inputs) + + input = tf.random.normal(shape=(100, 50)) + model_dynamic = model() + model_dynamic.train() + cls.assertEqual(model_dynamic(input).shape, (100, 10)) + cls.assertEqual(len(model_dynamic.all_weights), 3) + cls.assertEqual(len(model_dynamic.trainable_weights), 3) + model_dynamic.layer.input_layer.b.assign_add(tf.ones((20, ))) + cls.assertEqual(np.sum(model_dynamic.all_weights[-1].numpy() - tf.ones(20, ).numpy()), 0) + + ni = tl.layers.Input(shape=(100, 50)) + nn = MyLayer(name='mylayer2')(ni) + model_static = tl.models.Model(inputs=ni, outputs=nn) + model_static.eval() + cls.assertEqual(model_static(input).shape, (100, 10)) + cls.assertEqual(len(model_static.all_weights), 3) + cls.assertEqual(len(model_static.trainable_weights), 3) + model_static.get_layer('mylayer2').input_layer.b.assign_add(tf.ones((20, ))) + cls.assertEqual(np.sum(model_static.all_weights[-1].numpy() - tf.ones(20, ).numpy()), 0) + + +if __name__ == '__main__': + + tl.logging.set_verbosity(tl.logging.DEBUG) + + unittest.main() diff --git a/tests/models/test_model_core.py b/tests/models/test_model_core.py index 28e57b0f2..3db470f9d 100644 --- a/tests/models/test_model_core.py +++ b/tests/models/test_model_core.py @@ -401,6 +401,25 @@ def test_model_weights_copy(self): new_len = len(model_weights) self.assertEqual(new_len - 1, ori_len) + def test_inchannels_exception(self): + print('-' * 20, 'test_inchannels_exception', '-' * 20) + + class my_model(Model): + + def __init__(self): + super(my_model, self).__init__() + self.dense = Dense(64) + self.vgg = tl.models.vgg16() + + def forward(self, x): + return x + + try: + M = my_model() + except Exception as e: + self.assertIsInstance(e, AttributeError) + print(e) + if __name__ == '__main__':