Ray Tune: batch_size should be a positive integer value, but got batch_size=<ray.tune.search.sample.Categorical object
Question
I am trying to tune a neural network using Ray Tune. I followed the standard flow to get it running on MNIST data. The data loading code is:
trainset = torchvision.datasets.MNIST(
    root='../data', train=True, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
testset = torchvision.datasets.MNIST(
    root='../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=config_set["batch_size"], shuffle=True)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=1000, shuffle=True)
When I run the tuning with the configurable hyperparameters, it throws an error:
config_set = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([16, 32, 64, 128])
}
result = tune.run(
    train_model, fail_fast="raise", config=config_set)
*** ValueError: batch_size should be a positive integer value, but got batch_size=<ray.tune.search.sample.Categorical object at ***
Answer 1
Score: 2
For custom training code, Tune lets you wrap it in a Function Trainable. Tune calls that function and passes it a resolved config dict, in which each search space entry has been replaced by a sampled value. Currently, you're passing in the unresolved search space object (the Categorical object returned by tune.choice), so the DataLoader receives that object instead of an integer.
import torch
import torchvision
from torchvision import transforms

from ray import air, tune
from ray.air import session  # available for reporting metrics from inside the trainable

# Wrap the training code in a function; Tune calls it once per trial with a resolved config
def trainable(config: dict):
    # Your training code...
    trainset = torchvision.datasets.MNIST(
        root='../data', train=True, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))
    testset = torchvision.datasets.MNIST(
        root='../data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))
    train_loader = torch.utils.data.DataLoader(
        trainset,
        # config["batch_size"] is a concrete integer here, not a Categorical object
        batch_size=config["batch_size"], shuffle=True)
    train_model(...)

# The search space is defined outside the trainable and handed to the Tuner
config_set = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([16, 32, 64, 128])
}

tuner = tune.Tuner(
    trainable,
    param_space=config_set,
    run_config=air.RunConfig(
        failure_config=air.FailureConfig(fail_fast="raise")
    ),
)
results = tuner.fit()
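To make the difference between an unresolved search space and a resolved config concrete, here is a minimal pure-Python sketch of the idea. It does not use Ray or PyTorch: Choice, resolve, make_loader, and run_trials are hypothetical stand-ins written for illustration only, and the real Tune internals are more involved. It only shows why indexing the search space directly hands the loader an object, while a config resolved per trial hands it an integer.

```python
import random
from typing import Any, Callable, Dict, List


class Choice:
    """Hypothetical stand-in for a search space object such as the one tune.choice returns."""

    def __init__(self, options: List[Any]) -> None:
        self.options = options

    def sample(self, rng: random.Random) -> Any:
        # Pick one concrete value from the listed options.
        return rng.choice(self.options)


def resolve(space: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
    """Replace every Choice in the search space with one sampled value."""
    return {
        key: value.sample(rng) if isinstance(value, Choice) else value
        for key, value in space.items()
    }


def make_loader(batch_size: Any) -> int:
    """Hypothetical check that mimics the validation behind the error in the question."""
    if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError(
            "batch_size should be a positive integer value, "
            f"but got batch_size={batch_size!r}"
        )
    return batch_size


def run_trials(
    trainable: Callable[[Dict[str, Any]], Any],
    space: Dict[str, Any],
    num_samples: int,
    seed: int = 0,
) -> List[Any]:
    """Resolve the search space once per trial, then call the trainable with the result."""
    rng = random.Random(seed)
    return [trainable(resolve(space, rng)) for _ in range(num_samples)]


search_space = {"batch_size": Choice([16, 32, 64, 128]), "epochs": 1}


def trainable(config: Dict[str, Any]) -> int:
    # Inside the trainable, config holds plain values, so the check passes.
    return make_loader(config["batch_size"])
```

Calling make_loader(search_space["batch_size"]) raises the ValueError, which mirrors the code in the question. Calling run_trials(trainable, search_space, num_samples=5) succeeds, because each trial sees a sampled integer, which is the role the Function Trainable plays in the answer above.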