What does 'TypeError: argument of type 'Adam' is not iterable' mean?


Question


> Hello, I'm trying to build a model that outputs a super-resolution image from low- and high-resolution input images. My first error was: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after?. I first tried to solve it the way described in UserWarning: Discrepancy between trainable weights and collected trainable weights error (https://stackoverflow.com/questions/49091553/userwarning-discrepancy-between-trainable-weights-and-collected-trainable-weigh), but that didn't work, since ISR apparently has no 'compile' attribute. After reading through the docs I thought I had it, but now I'm getting the error below. I know what "not iterable" means in general; I just don't understand how it applies here.

import os
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from ISR.models import RRDN, Discriminator
from ISR.models import Cut_VGG19
from ISR.train import Trainer

# Define paths for HR and LR input images
lr_train_dir = 'C:/data/lr_train_150/',
hr_train_dir = 'C:/data/hr_train_150/',
lr_valid_dir = 'C:/data/lr_test_150/',
hr_valid_dir = 'C:/data/hr_test_150/',

lr_train_patch_size = 22 #size of my LR image
layers_to_extract = [5, 9]
scale = 10
hr_train_patch_size = lr_train_patch_size * scale # 220 Size of my HR image


# Instantiate models
rrdn = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)
discriminator = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
feature_extractor = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)


# Define optimizer and loss function
optimizer = Adam(1e-4, beta_1=0.9, beta_2=0.999)
#loss = 'mse'
loss_weights = {
  'generator': 0.0,
  'feature_extractor': 0.0833,
  'discriminator': 0.01
}
losses = {
  'generator': 'mae',
  'feature_extractor': 'mse',
  'discriminator': 'binary_crossentropy'
} 
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
log_dirs = {'logs': './logs', 'weights': './weights'}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}


# Define trainer
trainer = Trainer(generator=rrdn, 
                  discriminator=discriminator,
                  feature_extractor=feature_extractor, 
                  #name='srgan', 
                  log_dirs=log_dirs,
                  #checkpoint_dir='./models', 
                  learning_rate=learning_rate, 
                  losses=losses, 
                  flatness = flatness,
                  loss_weights = loss_weights,
                  adam_optimizer=optimizer,
                  lr_train_dir = 'C:/data/lr_train_150/',
                  hr_train_dir = 'C:/data/hr_train_150/',
                  lr_valid_dir = 'C:/data/lr_test_150/',
                  hr_valid_dir = 'C:/data/hr_test_150/'
                 )

# Train the model
trainer.train(batch_size=16, 
              steps_per_epoch=20, 
              #validation_steps=10, 
              epochs=1, 
              #print_frequency=100
              monitored_metrics={'val_generator_PSNR_Y': 'max'}
             )

And the error I'm getting:

TypeError                                 Traceback (most recent call last)
Cell In[18], line 52
     46 flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
     47 # Define feature extractor
     48 #vgg = Model(inputs=rrdn.input, outputs=rrdn.get_layer('features').output)
     49 #vgg.trainable = False
     50 
     51 # Define trainer
---> 52 trainer = Trainer(generator=rrdn, 
     53                   discriminator=discriminator,
     54                   feature_extractor=feature_extractor, 
     55                   #name='srgan', 
     56                   log_dirs=log_dirs,
     57                   #checkpoint_dir='./models', 
     58                   learning_rate=learning_rate, 
     59                   losses=losses, 
     60                   flatness = flatness,
     61                   loss_weights = loss_weights,
     62                   adam_optimizer=optimizer,
     63                   lr_train_dir = 'C:/data/lr_train_150/',
     64                   hr_train_dir = 'C:/data/hr_train_150/',
     65                   lr_valid_dir = 'C:/data/lr_test_150/',
     66                   hr_valid_dir = 'C:/data/hr_test_150/'
     67                  )
     69 # Train the model
     70 trainer.train(train_lr_dir=lr_train_dir, train_hr_dir=hr_train_dir, 
     71               valid_lr_dir=lr_valid_dir, valid_hr_dir=hr_valid_dir, 
     72               batch_size=16, 
   (...)
     77               monitored_metrics={'val_generator_PSNR_Y': 'max'}
     78              )

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\train\trainer.py:104, in Trainer.__init__(self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights, log_dirs, fallback_save_every_n_epochs, dataname, weights_generator, weights_discriminator, n_validation, flatness, learning_rate, adam_optimizer, losses, metrics)
    102 elif self.metrics['generator'] == 'PSNR':
    103     self.metrics['generator'] = PSNR
--> 104 self._parameters_sanity_check()
    105 self.model = self._combine_networks()
    107 self.settings = {}

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\train\trainer.py:163, in Trainer._parameters_sanity_check(self)
    151 check_parameter_keys(
    152     self.learning_rate,
    153     needed_keys=['initial_value'],
    154     optional_keys=['decay_factor', 'decay_frequency'],
    155     default_value=None,
    156 )
    157 check_parameter_keys(
    158     self.flatness,
    159     needed_keys=[],
    160     optional_keys=['min', 'increase_frequency', 'increase', 'max'],
    161     default_value=0.0,
    162 )
--> 163 check_parameter_keys(
    164     self.adam_optimizer,
    165     needed_keys=['beta1', 'beta2'],
    166     optional_keys=['epsilon'],
    167     default_value=None,
    168 )
    169 check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights'])

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\utils\utils.py:45, in check_parameter_keys(parameter, needed_keys, optional_keys, default_value)
     43 if needed_keys:
     44     for key in needed_keys:
---> 45         if key not in parameter:
     46             logger.error('{p} is missing key {k}'.format(p=parameter, k=key))
     47             raise

TypeError: argument of type 'Adam' is not iterable

ISR docs
Basic ISR implementation on Colab

When I tried this on Colab I got the same warning: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after?
'Discrepancy between trainable weights and collected trainable'
I don't think the code works the way it's supposed to, but I can't be sure, since I'm new to this.

Answer 1

Score: 0


From the error trace it can be seen that the sanity checks fail on the optimizer definition. A closer look at the image-super-resolution docs shows that the optimizer has to be defined differently than in plain TensorFlow code: you pass a dictionary of parameters instead of a TensorFlow optimizer object.
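
The reason is visible in the traceback above: ISR's check_parameter_keys runs a membership test (key not in parameter) on whatever is passed as adam_optimizer, and a Keras Adam instance does not support membership tests, so the sanity check raises before training even starts. A minimal standalone sketch (not ISR code, assuming the TensorFlow version from the question) that reproduces the message:

from tensorflow.keras.optimizers import Adam

optimizer = Adam(1e-4, beta_1=0.9, beta_2=0.999)

# This mirrors what ISR's check_parameter_keys does internally:
# a membership test on an Adam object raises the reported TypeError.
try:
    'beta1' in optimizer
except TypeError as e:
    print(e)  # argument of type 'Adam' is not iterable

# A plain dict supports the same test without complaint.
print('beta1' in {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None})  # True

The fix, as the docs show, is to pass the optimizer parameters as a plain dictionary: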

optimizer = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}
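
For completeness, here is roughly how the corrected definition slots into the Trainer call from the question (only adam_optimizer changes; every other argument is as in the original code):

trainer = Trainer(generator=rrdn,
                  discriminator=discriminator,
                  feature_extractor=feature_extractor,
                  log_dirs=log_dirs,
                  learning_rate=learning_rate,
                  losses=losses,
                  flatness=flatness,
                  loss_weights=loss_weights,
                  adam_optimizer=optimizer,  # now a parameter dict, not Adam(...)
                  lr_train_dir='C:/data/lr_train_150/',
                  hr_train_dir='C:/data/hr_train_150/',
                  lr_valid_dir='C:/data/lr_test_150/',
                  hr_valid_dir='C:/data/hr_test_150/'
                 )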

This properly answers the original question.

Now, as for the broader question: other training issues are likely caused by non-compliant input data. The model, for instance, expects the same number of LR/HR pairs with matching file names, and not all computer vision datasets follow this convention (a quick way to check is sketched below). Another issue may be the image format itself. I have prepared a notebook that attaches your model to Urban 100 data. Finally, on a general note, expect difficulties when working with unmaintained software like this ISR package.
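
As a quick illustration of the LR/HR pairing requirement, a small standalone check (plain Python, not ISR functionality; directory paths taken from the question) that the training directories contain matching file names:

import os

lr_dir = 'C:/data/lr_train_150/'
hr_dir = 'C:/data/hr_train_150/'

lr_files = set(os.listdir(lr_dir))
hr_files = set(os.listdir(hr_dir))

# Images present in one directory but missing from the other break the pairing.
print('LR only:', sorted(lr_files - hr_files))
print('HR only:', sorted(hr_files - lr_files))
print('Matched pairs:', len(lr_files & hr_files))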
