Keras中如何使用Capsule网络

avatar
作者
筋斗云
阅读量:0

Keras本身并没有内置的Capsule层（keras.layers中不存在CapsulePrimaryCap这样的层），因此需要自行实现。下面的示例定义了PrimaryCap（主胶囊）、带动态路由的CapsuleLayer，以及计算胶囊向量长度的Length层，并用它们搭建一个简单的Capsule网络:

"""Minimal capsule network (CapsNet) example written with Keras.

The layers call TensorFlow ops directly (the example already relied on
``tf.zeros`` / ``tf.nn.softmax``), so a TensorFlow backend is required.
"""
import tensorflow as tf
from keras import backend as K
from keras import initializers, layers
from keras.models import Model


def squash(vectors, axis=-1):
    """Return the capsule "squash" of ``vectors`` along ``axis``.

    Each vector s becomes ``|s|^2 / (1 + |s|^2) * s / |s|``: its direction is
    kept and its length is mapped into [0, 1), so the length can serve as a
    presence score.  ``K.epsilon()`` keeps the division finite for all-zero
    vectors (the previous plain normalisation produced NaN there).
    """
    squared_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    scale = squared_norm / (1.0 + squared_norm) / tf.sqrt(squared_norm + K.epsilon())
    return scale * vectors


def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, down_weight=0.5):
    """Capsule margin loss that accepts integer (sparse) class labels.

    Args:
        y_true: integer class ids of shape (batch,) or (batch, 1) -- the same
            labels ``sparse_categorical_crossentropy`` accepts, so existing
            ``fit`` calls keep working.
        y_pred: capsule lengths of shape (batch, n_class), e.g. ``Length`` output.
        m_plus: the correct class's length is penalised while below this value.
        m_minus: every other class's length is penalised while above this value.
        down_weight: weight of the absent-class term.

    Returns:
        Per-sample loss of shape (batch,).
    """
    # Labels may arrive as (batch, 1) and/or as floats; normalise to int ids.
    labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    targets = tf.one_hot(labels, depth=tf.shape(y_pred)[-1], dtype=y_pred.dtype)
    present = targets * tf.square(tf.nn.relu(m_plus - y_pred))
    absent = down_weight * (1.0 - targets) * tf.square(tf.nn.relu(y_pred - m_minus))
    return tf.reduce_sum(present + absent, axis=-1)


def CapsuleModel(input_shape, n_class, routings):
    """Build a capsule-network classifier.

    Args:
        input_shape: shape of one input image without the batch axis, e.g.
            ``(28, 28, 1)`` (channels_last assumed -- Conv2D's default).
        n_class: number of output classes (one output capsule per class).
        routings: number of dynamic-routing iterations (>= 1).

    Returns:
        A ``Model`` mapping images to per-class capsule lengths, (batch, n_class).
    """
    x = layers.Input(shape=input_shape)

    # Ordinary convolutional feature extractor feeding the primary capsules.
    conv1 = layers.Conv2D(128, (9, 9), activation='relu', padding='valid')(x)
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32,
                             kernel_size=9, strides=2, padding='valid')

    # One 16-D output capsule per class, computed by routing-by-agreement.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16,
                             routings=routings)(primarycaps)

    # The length of each class capsule is used as that class's score.
    out_caps = Length()(digitcaps)

    return Model(x, out_caps)


def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    """Primary capsule layer: convolution, regroup into vectors, squash.

    Args:
        inputs: 4-D feature map (batch, height, width, channels).
        dim_capsule: length of each capsule vector.
        n_channels: number of capsule types per spatial location.
        kernel_size, strides, padding: forwarded to ``Conv2D``.

    Returns:
        Tensor of shape (batch, num_capsules, dim_capsule).
    """
    output = layers.Conv2D(filters=dim_capsule * n_channels, kernel_size=kernel_size,
                           strides=strides, padding=padding)(inputs)
    outputs = layers.Reshape(target_shape=(-1, dim_capsule))(output)
    # Squash rather than unit-normalise: keeps length information and is
    # well-defined for zero vectors.
    return layers.Lambda(squash)(outputs)


class CapsuleLayer(layers.Layer):
    """Fully connected capsule layer with dynamic routing-by-agreement.

    Input:  (batch, input_num_capsule, input_dim_capsule); both trailing
            dimensions must be statically known.
    Output: (batch, num_capsule, dim_capsule).

    Args:
        num_capsule: number of output capsules.
        dim_capsule: length of each output capsule vector.
        routings: number of routing iterations; must be >= 1.
        kernel_initializer: initializer for the transformation matrices.

    Raises:
        ValueError: if ``routings < 1`` or the input shape is unsuitable.
    """

    def __init__(self, num_capsule, dim_capsule, routings=3,
                 kernel_initializer='glorot_uniform', **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        if routings < 1:
            raise ValueError(f'routings must be >= 1, got {routings}')
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        input_shape = tuple(input_shape)
        if len(input_shape) != 3 or None in input_shape[1:]:
            raise ValueError(
                'CapsuleLayer expects input of shape '
                f'(batch, input_num_capsule, input_dim_capsule), got {input_shape}')
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]
        # One transformation matrix W[j, i] (dim_capsule x input_dim_capsule)
        # per (output capsule j, input capsule i) pair.
        self.W = self.add_weight(
            shape=(self.num_capsule, self.input_num_capsule,
                   self.dim_capsule, self.input_dim_capsule),
            initializer=self.kernel_initializer,
            name='W',
        )
        self.built = True

    def call(self, inputs):
        # Prediction vectors u_hat[b, j, i] = W[j, i] @ u[b, i]
        # -> (batch, num_capsule, input_num_capsule, dim_capsule).
        inputs_hat = tf.einsum('jidk,bik->bjid', self.W, inputs)

        # Routing logits start at zero, i.e. uniform coupling.
        # Shape: (batch, num_capsule, input_num_capsule).
        b = tf.zeros_like(inputs_hat[..., 0])
        outputs = None
        for i in range(self.routings):
            # Coupling coefficients: softmax over the output capsules (axis 1).
            c = tf.nn.softmax(b, axis=1)
            # Weighted sum of predictions, then squash -> (batch, num_capsule, dim_capsule).
            outputs = squash(tf.einsum('bji,bjid->bjd', c, inputs_hat))
            if i < self.routings - 1:
                # Agreement (dot product) between output and prediction
                # increases the logit of consistent predictions.
                b += tf.einsum('bjd,bjid->bji', outputs, inputs_hat)
        return outputs

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsule, self.dim_capsule)

    def get_config(self):
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
        })
        return config


class Length(layers.Layer):
    """Euclidean length of each capsule vector: (batch, n, dim) -> (batch, n)."""

    def call(self, inputs, **kwargs):
        # K.epsilon() keeps the gradient of sqrt finite for all-zero vectors.
        return tf.sqrt(tf.reduce_sum(tf.square(inputs), axis=-1) + K.epsilon())

    def compute_output_shape(self, input_shape):
        return input_shape[:-1]


# Build and compile the example: 28x28 single-channel inputs, 10 classes,
# 3 routing iterations.  The labels are integer class ids, as before.
model = CapsuleModel((28, 28, 1), 10, 3)
model.compile(optimizer='adam', loss=margin_loss, metrics=['accuracy'])

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!