【kan】hello, kan

avatar
作者
猴君
阅读量:0

目录

安装

初始化kan

创建数据集

训练KAN

剪枝 KAN

继续训练

设置公式

继续训练

获得符号公式


安装

pip install pykan

初始化kan

from kan import * torch.set_default_dtype(torch.float64) # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). model = KAN(width=[2,5,1], grid=3, k=3, seed=42)
checkpoint directory created: ./model saving model version 0.0

创建数据集

from kan.utils import create_dataset # create dataset f(x,y) = exp(sin(pi*x)+y^2) f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) dataset = create_dataset(f, n_var=2) dataset['train_input'].shape, dataset['train_label'].shape
(torch.Size([1000, 2]), torch.Size([1000, 1]))

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!