阅读量:4
在PyTorch中进行超参数搜索通常有两种常用的方法:
- 使用Grid Search:通过定义一个超参数的候选值列表,对所有可能的组合进行穷举搜索,选择表现最好的超参数组合。可以使用GridSearchCV类来实现这一过程。
from sklearn.model_selection import GridSearchCV from sklearn.metrics import accuracy_score from torch import nn, optim from torch.utils.data import DataLoader # Define your model class MyModel(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(MyModel, self).__init__() self.hidden = nn.Linear(input_dim, hidden_dim) self.relu = nn.ReLU() self.output = nn.Linear(hidden_dim, output_dim) def forward(self, x): x = self.hidden(x) x = self.relu(x) x = self.output(x) return x # Define your dataset and dataloader # dataset = ... # dataloader = DataLoader(dataset, batch_size=64, shuffle=True) # Define parameter grid param_grid = { 'hidden_dim': [64, 128, 256], 'learning_rate': [0.001, 0.01, 0.1] } # Create a GridSearchCV object grid_search = GridSearchCV(MyModel, param_grid, scoring='accuracy', cv=3) # Fit the model grid_search.fit(dataloader) # Print best parameters print(grid_search.best_params_)
- 使用Random Search:与Grid Search不同,Random Search是随机地在指定的参数空间中采样,从而更有效地搜索超参数空间。可以使用RandomizedSearchCV类来实现这一过程。
from sklearn.model_selection import RandomizedSearchCV from sklearn.metrics import accuracy_score from torch import nn, optim from torch.utils.data import DataLoader # Define your model # Define your model class MyModel(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(MyModel, self).__init__() self.hidden = nn.Linear(input_dim, hidden_dim) self.relu = nn.ReLU() self.output = nn.Linear(hidden_dim, output_dim) def forward(self, x): x = self.hidden(x) x = self.relu(x) x = self.output(x) return x # Define your dataset and dataloader # dataset = ... # dataloader = DataLoader(dataset, batch_size=64, shuffle=True) # Define parameter grid param_dist = { 'hidden_dim': [64, 128, 256], 'learning_rate': [0.001, 0.01, 0.1] } # Create a RandomizedSearchCV object random_search = RandomizedSearchCV(MyModel, param_dist, n_iter=10, scoring='accuracy', cv=3) # Fit the model random_search.fit(dataloader) # Print best parameters print(random_search.best_params_)
无论选择哪种方法,超参数搜索是一个耗时的过程,需要谨慎选择超参数的范围和步长,以及合适的评估指标来评估模型性能。