费雪信息
之前写过的文章20240616日志:大模型压缩方法DMS里具体介绍了费雪信息,在一组观测数据中,Fisher信息量越大,对未知参数的估计就越准确。
I w def = E [ ( ∂ ∂ w log p ( D ∣ w ) ) 2 ] (1) I_w^{\text{def}}=\mathbb{E}\left[\left(\frac{\partial}{\partial w}\log p(\mathcal{D}|w)\right)^2\right]\tag{1} Iwdef=E[(∂w∂logp(D∣w))2](1)
但是,Fisher信息量计算代价太大。
FWSVD: Fisher-Weighted SVD
寻求一个基于经验的费雪信息量 I w e m p I_w^{\mathrm{emp}} Iwemp,用公式2表示
I w d e f ≈ I w e m p = 1 ∣ D ∣ ∑ i = 1 ∣ D ∣ ( ∂ ∂ w L ( d i ; w ) ) 2 (2) I_w^{\mathrm{def}}\approx I_w^{\mathrm{emp}}=\frac{1}{|\mathcal{D}|}\sum_{i=1}^{|\mathcal{D}|}\left(\frac{\partial}{\partial w}\mathcal{L}\left(d_i;w\right)\right)^2\tag{2} Iwdef≈Iwemp=∣D∣1i=1∑∣D∣(∂w∂L(di;w))2(2)
令 w ^ = S V D ( w ) \hat{w} = \mathrm{SVD}(w) w^=SVD(w),由此目标函数定义为
min rank w ^ = r ∥ I w emp ∗ ( w − w ^ ) ∥ 2 (3) \min_{\text{rank }\hat{w}=r}\|\sqrt{I_w^\text{emp}}*(w-\hat{w})\|^2\tag{3} rank w^=rmin∥Iwemp∗(w−w^)∥2(3)
对 I w e m p I_w^{\mathrm{emp}} Iwemp使用行加权
I ^ w e m p = d i a g ( I w e m p ⋅ 1 ) (4) \hat{I}_w^{\mathrm{emp}}=\mathrm{diag}\left(I_w^{\mathrm{emp}}\cdot\mathbf{1}\right)\tag{4} I^wemp=diag(Iwemp⋅1)(4)
由此可得加权的SVD
F W S V D ( w ) ≈ U ^ Σ ^ V ^ = ( I ^ w e m p ) − 1 U Σ V (5) \mathrm{FWSVD}(w)\approx\hat{U}\hat{\Sigma}\hat{V}=(\hat{I}_{w}^{\mathrm{emp}})^{-1}U\Sigma V\tag{5} FWSVD(w)≈U^Σ^V^=(I^wemp)−1UΣV(5)
增强LoRA FWSVD压缩
为了不计算每一个权重的信息量,这里引入
I ^ w e m p ≈ I ^ Δ w e m p = I ^ B e m p I ^ A e m p (6) \hat{I}_w^{\mathrm{emp}}\approx\hat{I}_{\Delta w}^{\mathrm{emp}}=\hat{I}_B^{\mathrm{emp}}\hat{I}_A^{\mathrm{emp}}\tag{6} I^wemp≈I^Δwemp=I^BempI^Aemp(6)
这里的"≈"的意思是使用后面的 I ^ Δ w e m p \hat{I}_{\Delta w}^{\mathrm{emp}} I^Δwemp去近似前面的 I ^ w e m p \hat{I}_w^{\mathrm{emp}} I^wemp, Δ w \Delta w Δw的意思并不是变化率。使用LoRA的思想进行微调,然后把 w + Δ w w+\Delta w w+Δw代替原来的 w w w,然后使用 I ^ Δ w \hat{I}_{\Delta w} I^Δw压缩
F W S V D ( w ) ≈ F W S V D ( Δ w ) = S V D ( I ^ Δ w e m p ( w + Δ w ) ) = S V D ( I ^ Δ w e m p w ) (7) \begin{aligned} \mathrm{FWSVD}(w)& \approx\mathrm{FWSVD}(\Delta w) \\ &=\mathrm{SVD}(\hat{I}_{\Delta w}^{\mathrm{emp}}(w+\Delta w)) \\ &=\mathrm{SVD}(\hat{I}_{\Delta w}^{\mathrm{emp}}w) \end{aligned}\tag{7} FWSVD(w)≈FWSVD(Δw)=SVD(I^Δwemp(w+Δw))=SVD(I^Δwempw)(7)
reference
[1]:ACL 2024 Parameter and Memory Efficient Language Model Compression using Fisher Informationfrom Low-Rank Representations