Installation
pip install pykan
Initialize a KAN
import torch
from kan import *

# Use float64 as the default tensor dtype
torch.set_default_dtype(torch.float64)

# Create a KAN: 2D inputs, 1D output, and 5 hidden neurons.
# Cubic splines (k=3) on 3 grid intervals (grid=3).
model = KAN(width=[2,5,1], grid=3, k=3, seed=42)
checkpoint directory created: ./model
saving model version 0.0
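In a KAN, every edge carries a learnable univariate function parameterized as a spline; here k=3 selects cubic splines and grid=3 sets the number of grid intervals. As a rough, library-independent illustration (not pykan's internal code), the sketch below builds uniform knots for 3 intervals on an assumed input range of [-1, 1], pads them with k extra knots on each side, and evaluates the degree-k B-spline basis with the Cox-de Boor recursion. Under these assumptions each edge has grid + k = 6 basis functions, and a width=[2,5,1] network has 2·5 + 5·1 = 15 edges. The helper names (extended_uniform_knots, bspline_basis, count_edge_functions) are hypothetical.

```python
def extended_uniform_knots(lo, hi, grid, k):
    """Uniform knots over [lo, hi] with `grid` intervals, padded by k extra knots per side."""
    h = (hi - lo) / grid
    return [lo + (i - k) * h for i in range(grid + 2 * k + 1)]


def bspline_basis(x, knots, k):
    """Evaluate every degree-k B-spline basis function at x (Cox-de Boor recursion)."""
    # Degree 0: indicator of the half-open knot interval [t_i, t_{i+1})
    B = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        nxt = []
        for i in range(len(knots) - 1 - d):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            left = (x - knots[i]) / left_den * B[i] if left_den else 0.0
            right = (knots[i + d + 1] - x) / right_den * B[i + 1] if right_den else 0.0
            nxt.append(left + right)
        B = nxt  # one fewer basis function per degree raised
    return B


def count_edge_functions(width):
    """Number of learnable univariate functions (edges) in a fully connected KAN."""
    return sum(a * b for a, b in zip(width[:-1], width[1:]))


grid, k = 3, 3
knots = extended_uniform_knots(-1.0, 1.0, grid, k)
basis = bspline_basis(0.0, knots, k)
print(len(basis))                       # 6 == grid + k
print(round(sum(basis), 12))            # 1.0 (partition of unity inside [-1, 1])
print(count_edge_functions([2, 5, 1]))  # 15
```

Each edge's function is then a weighted sum of these basis functions with learnable coefficients. In the KAN paper each activation also includes a residual base function, which this sketch omits.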
Create a dataset
from kan.utils import create_dataset

# Create a dataset for f(x, y) = exp(sin(pi*x) + y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
(torch.Size([1000, 2]), torch.Size([1000, 1]))
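Here create_dataset receives a function that maps a batch of inputs (one column per variable) to a column of labels, and the output shows 1000 training rows with 2 input features and 1 label. The plain-Python sketch below mimics that idea without torch: it draws each coordinate uniformly from an assumed range of [-1, 1], labels every sample with f, and returns train/test splits keyed like the dataset dictionary. make_dataset, its default arguments, and the per-sample form of f are illustrative assumptions, not pykan's implementation.

```python
import math
import random


def make_dataset(f, n_var=2, train_num=1000, test_num=1000, lo=-1.0, hi=1.0, seed=0):
    """Hypothetical plain-Python analogue of a create_dataset-style helper.

    Samples each input coordinate uniformly from [lo, hi] and labels it with f.
    Inputs are lists shaped (num, n_var); labels are lists shaped (num, 1).
    """
    rng = random.Random(seed)  # seeded for reproducible samples

    def sample(num):
        xs = [[rng.uniform(lo, hi) for _ in range(n_var)] for _ in range(num)]
        ys = [[f(x)] for x in xs]
        return xs, ys

    train_x, train_y = sample(train_num)
    test_x, test_y = sample(test_num)
    return {"train_input": train_x, "train_label": train_y,
            "test_input": test_x, "test_label": test_y}


# Same target as above, written for a single sample x = [x0, x1]
f = lambda x: math.exp(math.sin(math.pi * x[0]) + x[1] ** 2)
dataset = make_dataset(f, n_var=2)
print(len(dataset["train_input"]), len(dataset["train_input"][0]))  # 1000 2
print(len(dataset["train_label"]), len(dataset["train_label"][0]))  # 1000 1
```

Keeping a fixed seed makes the sampled points reproducible between runs, which helps when comparing models trained on the same data.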