阅读量:0
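Keras 并没有内置 Capsule 或 PrimaryCap 层(`keras.layers` 中不存在这两个层),因此在 Keras 中实现 Capsule 网络需要自己定义它们:用一个函数构建初级胶囊层(PrimaryCap),用自定义层实现带动态路由的胶囊层(CapsuleLayer)。下面是一个简单的示例(基于 TensorFlow 后端):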
# Capsule network (CapsNet) example for Keras.
# The layers below call TensorFlow ops directly, so Keras must run on the
# TensorFlow backend.
import tensorflow as tf
from keras import initializers, layers
from keras.models import Model

# Small constant that keeps square roots and divisions finite for all-zero vectors.
_EPSILON = 1e-7


def squash(vectors, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector ``s`` is rescaled to ``|s|^2 / (1 + |s|^2) * s / |s|``: its
    direction is kept while its length is mapped into [0, 1).

    Args:
        vectors: Tensor of capsule vectors.
        axis: Axis that holds the capsule dimensions.

    Returns:
        Tensor with the same shape as ``vectors``.
    """
    squared_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    scale = squared_norm / (1.0 + squared_norm) / tf.sqrt(squared_norm + _EPSILON)
    return scale * vectors


def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    """Build the primary-capsule layer on top of a convolutional feature map.

    A convolution with ``dim_capsule * n_channels`` filters is reshaped into a
    sequence of ``dim_capsule``-dimensional vectors, which are then squashed.

    Args:
        inputs: Feature map fed to ``Conv2D`` (batch, height, width, channels).
        dim_capsule: Dimension of each primary capsule.
        n_channels: Number of capsule channels per spatial position.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode (e.g. ``'valid'``).

    Returns:
        Tensor of shape (batch, num_primary_capsules, dim_capsule).
    """
    output = layers.Conv2D(filters=dim_capsule * n_channels, kernel_size=kernel_size,
                           strides=strides, padding=padding)(inputs)
    outputs = layers.Reshape(target_shape=(-1, dim_capsule))(output)
    # Capsules must be squashed (length in [0, 1)); plain L2 normalisation would
    # force every length to 1 and divide by zero for all-zero vectors.
    return layers.Lambda(squash)(outputs)


class CapsuleLayer(layers.Layer):
    """Fully connected capsule layer with routing-by-agreement.

    Maps input capsules of shape (batch, input_num_capsule, input_dim_capsule)
    to output capsules of shape (batch, num_capsule, dim_capsule).

    Args:
        num_capsule: Number of output capsules.
        dim_capsule: Dimension of each output capsule.
        routings: Number of routing iterations; must be positive.
        kernel_initializer: Initializer for the transformation matrices ``W``.

    Raises:
        ValueError: If ``routings`` is not positive, or if the input is not a
            rank-3 tensor with known capsule count and dimension.
    """

    def __init__(self, num_capsule, dim_capsule, routings=3,
                 kernel_initializer='glorot_uniform', **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        if routings <= 0:
            raise ValueError(f"routings must be positive, got {routings}")
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        if len(input_shape) != 3 or input_shape[1] is None or input_shape[2] is None:
            raise ValueError(
                "CapsuleLayer expects inputs of shape "
                f"(batch, input_num_capsule, input_dim_capsule), got {input_shape}")
        self.input_num_capsule = int(input_shape[1])
        self.input_dim_capsule = int(input_shape[2])
        # One (dim_capsule x input_dim_capsule) transformation matrix per
        # (output capsule, input capsule) pair.
        self.W = self.add_weight(
            shape=(self.num_capsule, self.input_num_capsule,
                   self.dim_capsule, self.input_dim_capsule),
            initializer=self.kernel_initializer,
            name='W')
        self.built = True

    def call(self, inputs):
        # Prediction vectors u_hat[b, j, i] = W[j, i] @ u[b, i]:
        # (batch, in_num, in_dim) -> (batch, num_capsule, in_num, dim_capsule).
        inputs_hat = tf.einsum('bid,jiod->bjio', inputs, self.W)

        # Routing logits b[batch, j, i], initialised to zero.
        logits = tf.zeros_like(inputs_hat[..., 0])
        outputs = None
        for i in range(self.routings):
            # Each input capsule distributes its vote over the output capsules,
            # so the softmax runs over the num_capsule axis.
            couplings = tf.nn.softmax(logits, axis=1)
            outputs = squash(tf.einsum('bji,bjio->bjo', couplings, inputs_hat))
            if i < self.routings - 1:
                # Raise the logits by the agreement between outputs and predictions.
                logits += tf.einsum('bjo,bjio->bji', outputs, inputs_hat)
        return outputs

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsule, self.dim_capsule)

    def get_config(self):
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
        })
        return config


class Length(layers.Layer):
    """Compute the length (L2 norm) of each capsule vector.

    Input (batch, num_capsule, dim_capsule) -> output (batch, num_capsule).
    """

    def call(self, inputs, **kwargs):
        # Epsilon keeps the gradient of sqrt finite for all-zero capsules.
        return tf.sqrt(tf.reduce_sum(tf.square(inputs), axis=-1) + _EPSILON)

    def compute_output_shape(self, input_shape):
        return input_shape[:-1]


def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, down_weight=0.5):
    """CapsNet margin loss for integer class labels.

    Args:
        y_true: Integer class ids, shape (batch,) or (batch, 1) -- the same
            sparse label format used with ``sparse_categorical_crossentropy``.
        y_pred: Capsule lengths, shape (batch, n_class).
        m_plus: Lower target for the length of the correct class capsule.
        m_minus: Upper target for the lengths of the other capsules.
        down_weight: Weight of the loss term for absent classes.

    Returns:
        Per-sample loss of shape (batch,).
    """
    n_class = tf.shape(y_pred)[-1]
    targets = tf.one_hot(tf.cast(tf.reshape(y_true, [-1]), tf.int32),
                         depth=n_class, dtype=y_pred.dtype)
    present = targets * tf.square(tf.maximum(0.0, m_plus - y_pred))
    absent = down_weight * (1.0 - targets) * tf.square(tf.maximum(0.0, y_pred - m_minus))
    return tf.reduce_sum(present + absent, axis=-1)


def CapsuleModel(input_shape, n_class, routings):
    """Build a CapsNet classifier.

    Args:
        input_shape: Shape of one input image, e.g. (28, 28, 1).
        n_class: Number of classes (one output capsule per class).
        routings: Number of routing iterations in the class-capsule layer.

    Returns:
        A ``Model`` mapping images to per-class capsule lengths (batch, n_class).
    """
    x = layers.Input(shape=input_shape)
    # Ordinary convolution extracting local features.
    conv1 = layers.Conv2D(128, (9, 9), activation='relu', padding='valid')(x)
    # Primary capsules: 8-D vectors produced by a strided convolution.
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9,
                             strides=2, padding='valid')
    # Class capsules: one 16-D capsule per class, computed by dynamic routing.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings)(primarycaps)
    # The length of each class capsule serves as that class's score.
    out_caps = Length()(digitcaps)
    return Model(x, out_caps)


# Build and compile the CapsNet for 28x28x1 inputs, 10 classes, 3 routing iterations.
# Labels stay integer class ids, as with sparse_categorical_crossentropy.
model = CapsuleModel((28, 28, 1), 10, 3)
model.compile(optimizer='adam', loss=margin_loss, metrics=['accuracy'])