What does 'TypeError: argument of type 'Adam' is not iterable' mean?

Question

> Hello, I'm trying to build a model that outputs a super-resolution image from low- and high-resolution input images. My first problem was the warning "UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after?". I first tried the fix from the question "UserWarning: Discrepancy between trainable weights and collected trainable weights error" (https://stackoverflow.com/questions/49091553/userwarning-discrepancy-between-trainable-weights-and-collected-trainable-weigh), but it didn't work, since ISR apparently doesn't have a `compile` attribute. After reading through the docs I thought I had it, but now I'm getting the error below. I know what "not iterable" means in general; I just don't understand how it relates to this code.

```python
import os
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from ISR.models import RRDN, Discriminator
from ISR.models import Cut_VGG19
from ISR.train import Trainer

# Define paths for HR and LR input images
lr_train_dir = 'C:/data/lr_train_150/',
hr_train_dir = 'C:/data/hr_train_150/',
lr_valid_dir = 'C:/data/lr_test_150/',
hr_valid_dir = 'C:/data/hr_test_150/',
lr_train_patch_size = 22  # size of my LR image
layers_to_extract = [5, 9]
scale = 10
hr_train_patch_size = lr_train_patch_size * scale  # 220, size of my HR image

# Instantiate models
rrdn = RRDN(arch_params={'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': scale}, patch_size=lr_train_patch_size)
discriminator = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
feature_extractor = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

# Define optimizer and loss function
optimizer = Adam(1e-4, beta_1=0.9, beta_2=0.999)
#loss = 'mse'
loss_weights = {
    'generator': 0.0,
    'feature_extractor': 0.0833,
    'discriminator': 0.01
}
losses = {
    'generator': 'mae',
    'feature_extractor': 'mse',
    'discriminator': 'binary_crossentropy'
}
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
log_dirs = {'logs': './logs', 'weights': './weights'}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

# Define trainer
trainer = Trainer(generator=rrdn,
                  discriminator=discriminator,
                  feature_extractor=feature_extractor,
                  #name='srgan',
                  log_dirs=log_dirs,
                  #checkpoint_dir='./models',
                  learning_rate=learning_rate,
                  losses=losses,
                  flatness=flatness,
                  loss_weights=loss_weights,
                  adam_optimizer=optimizer,
                  lr_train_dir='C:/data/lr_train_150/',
                  hr_train_dir='C:/data/hr_train_150/',
                  lr_valid_dir='C:/data/lr_test_150/',
                  hr_valid_dir='C:/data/hr_test_150/'
                  )

# Train the model
trainer.train(batch_size=16,
              steps_per_epoch=20,
              #validation_steps=10,
              epochs=1,
              #print_frequency=100
              monitored_metrics={'val_generator_PSNR_Y': 'max'}
              )
```

And the error I'm getting:

```
TypeError                                 Traceback (most recent call last)
Cell In[18], line 52
     46 flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
     47 # Define feature extractor
     48 #vgg = Model(inputs=rrdn.input, outputs=rrdn.get_layer('features').output)
     49 #vgg.trainable = False
     50 
     51 # Define trainer
---> 52 trainer = Trainer(generator=rrdn,
     53                   discriminator=discriminator,
     54                   feature_extractor=feature_extractor,
     55                   #name='srgan',
     56                   log_dirs=log_dirs,
     57                   #checkpoint_dir='./models',
     58                   learning_rate=learning_rate,
     59                   losses=losses,
     60                   flatness = flatness,
     61                   loss_weights = loss_weights,
     62                   adam_optimizer=optimizer,
     63                   lr_train_dir = 'C:/data/lr_train_150/',
     64                   hr_train_dir = 'C:/data/hr_train_150/',
     65                   lr_valid_dir = 'C:/data/lr_test_150/',
     66                   hr_valid_dir = 'C:/data/hr_test_150/'
     67                   )
     69 # Train the model
     70 trainer.train(train_lr_dir=lr_train_dir, train_hr_dir=hr_train_dir,
     71               valid_lr_dir=lr_valid_dir, valid_hr_dir=hr_valid_dir,
     72               batch_size=16,
   (...)
     77               monitored_metrics={'val_generator_PSNR_Y': 'max'}
     78               )

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\train\trainer.py:104, in Trainer.__init__(self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights, log_dirs, fallback_save_every_n_epochs, dataname, weights_generator, weights_discriminator, n_validation, flatness, learning_rate, adam_optimizer, losses, metrics)
    102 elif self.metrics['generator'] == 'PSNR':
    103     self.metrics['generator'] = PSNR
--> 104 self._parameters_sanity_check()
    105 self.model = self._combine_networks()
    107 self.settings = {}

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\train\trainer.py:163, in Trainer._parameters_sanity_check(self)
    151 check_parameter_keys(
    152     self.learning_rate,
    153     needed_keys=['initial_value'],
    154     optional_keys=['decay_factor', 'decay_frequency'],
    155     default_value=None,
    156 )
    157 check_parameter_keys(
    158     self.flatness,
    159     needed_keys=[],
    160     optional_keys=['min', 'increase_frequency', 'increase', 'max'],
    161     default_value=0.0,
    162 )
--> 163 check_parameter_keys(
    164     self.adam_optimizer,
    165     needed_keys=['beta1', 'beta2'],
    166     optional_keys=['epsilon'],
    167     default_value=None,
    168 )
    169 check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights'])

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\utils\utils.py:45, in check_parameter_keys(parameter, needed_keys, optional_keys, default_value)
     43 if needed_keys:
     44     for key in needed_keys:
---> 45         if key not in parameter:
     46             logger.error('{p} is missing key {k}'.format(p=parameter, k=key))
     47             raise

TypeError: argument of type 'Adam' is not iterable
```

Links: ISR docs; Basic ISR implementation on Colab.
When I tried this on Colab, I got the same warning: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after ? (see 'Discrepancy between trainable weights and collected trainable').
I don't think the code works the way it's supposed to, but I can't be sure, since I'm new.

Answer 1

Score: 0

From the error trace, it is the sanity check on the optimizer definition that fails: ISR looks for the required keys `'beta1'` and `'beta2'` with `key not in parameter`, which only works on a dict-like object. A closer look at the image-super-resolution docs shows that you should define the optimizer differently than in plain TensorFlow code: pass a dictionary of parameters instead of a TensorFlow optimizer object.

```python
optimizer = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}
```
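To see why an optimizer object produces this particular message: Python's `in` operator needs a container (an object with `__contains__`, `__iter__` or `__getitem__`), and judging by the traceback the Keras `Adam` instance offers none of these, so the check crashes with a `TypeError` before it can even report a missing key. The sketch below is a simplified re-implementation based only on the lines visible in the traceback, plus an assumed handling of `optional_keys`; it is not ISR's actual source. `FakeAdam` is a hypothetical placeholder so the example runs without TensorFlow.

```python
import logging

logger = logging.getLogger(__name__)


def check_parameter_keys(parameter, needed_keys, optional_keys=None, default_value=None):
    """Simplified stand-in for ISR's sanity check (not the library's source).

    Needed keys must be present; missing optional keys are filled in with
    default_value (assumed behaviour). The `in` test only works on containers.
    """
    if needed_keys:
        for key in needed_keys:
            # Raises TypeError if `parameter` is not a container (e.g. an optimizer object).
            if key not in parameter:
                logger.error('%s is missing key %s', parameter, key)
                # ISR uses a bare `raise`; an explicit KeyError is clearer here.
                raise KeyError(key)
    if optional_keys:
        for key in optional_keys:
            if key not in parameter:
                parameter[key] = default_value


class FakeAdam:
    """Hypothetical placeholder for tf.keras.optimizers.Adam: an ordinary object
    with attributes but no __contains__, __iter__ or __getitem__."""

    def __init__(self, learning_rate, beta_1=0.9, beta_2=0.999):
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2


# A dict passes; the missing optional 'epsilon' is filled with the default.
adam_dict = {'beta1': 0.9, 'beta2': 0.999}
check_parameter_keys(adam_dict, needed_keys=['beta1', 'beta2'],
                     optional_keys=['epsilon'], default_value=None)
print(adam_dict)  # {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}

# An optimizer-like object fails on the first `in` test, as in the traceback.
try:
    check_parameter_keys(FakeAdam(1e-4, beta_1=0.9, beta_2=0.999),
                         needed_keys=['beta1', 'beta2'])
except TypeError as err:
    print('TypeError:', err)
```

The exact wording of the `TypeError` may differ slightly between Python versions, but it names the offending type just as the traceback names `Adam`.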

This properly answers the original question.
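Applied to the question's setup, the fix could look roughly like the sketch below. It assumes ISR is installed and reuses the paths and values from the question; the imports sit inside the hypothetical `make_trainer()` helper so the plain settings can be checked without ISR or TensorFlow. Two further changes are deliberate and worth double-checking against the ISR docs: the path variables lose their trailing commas (in Python, `x = 'a',` creates a one-element tuple), and `Cut_VGG19` is passed as the feature extractor and `Discriminator` as the discriminator, which is how the ISR README pairs them (the question has them the other way round). Also note that the `1e-4` given to `Adam()` in the question has no slot in the dict: as far as I can tell from the ISR source, the Trainer builds its own Adam optimizer from `adam_optimizer` and takes the learning rate from `learning_rate['initial_value']` (0.0004 here).

```python
# Sketch: the question's setup with ISR's dict-based optimizer settings.
# Assumes the ISR package (and TensorFlow) is installed when make_trainer() is called.

# Plain strings; a trailing comma would turn each path into a one-element tuple.
lr_train_dir = 'C:/data/lr_train_150/'
hr_train_dir = 'C:/data/hr_train_150/'
lr_valid_dir = 'C:/data/lr_test_150/'
hr_valid_dir = 'C:/data/hr_test_150/'

lr_train_patch_size = 22  # LR patch size from the question
layers_to_extract = [5, 9]
scale = 10
hr_train_patch_size = lr_train_patch_size * scale  # 220

# Adam settings as a dict with ISR's key names (beta1/beta2, not Keras' beta_1/beta_2).
adam_optimizer = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}
# The learning rate is configured here rather than on the optimizer.
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
loss_weights = {'generator': 0.0, 'feature_extractor': 0.0833, 'discriminator': 0.01}
losses = {'generator': 'mae', 'feature_extractor': 'mse', 'discriminator': 'binary_crossentropy'}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
log_dirs = {'logs': './logs', 'weights': './weights'}


def make_trainer():
    """Hypothetical helper: builds the ISR Trainer from the settings above."""
    # Deferred imports so the settings can be inspected without ISR installed.
    from ISR.models import RRDN, Discriminator, Cut_VGG19
    from ISR.train import Trainer

    rrdn = RRDN(arch_params={'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': scale},
                patch_size=lr_train_patch_size)
    # Paired as in the ISR README: Cut_VGG19 extracts features, Discriminator judges real/fake.
    feature_extractor = Cut_VGG19(patch_size=hr_train_patch_size,
                                  layers_to_extract=layers_to_extract)
    discriminator = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

    return Trainer(
        generator=rrdn,
        discriminator=discriminator,
        feature_extractor=feature_extractor,
        lr_train_dir=lr_train_dir,
        hr_train_dir=hr_train_dir,
        lr_valid_dir=lr_valid_dir,
        hr_valid_dir=hr_valid_dir,
        loss_weights=loss_weights,
        losses=losses,
        learning_rate=learning_rate,
        flatness=flatness,
        log_dirs=log_dirs,
        adam_optimizer=adam_optimizer,  # a dict, not a tf.keras Adam instance
    )
```

With ISR installed, training would then follow the question's own call: `trainer = make_trainer()` followed by `trainer.train(batch_size=16, steps_per_epoch=20, epochs=1, monitored_metrics={'val_generator_PSNR_Y': 'max'})`.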

Now, as for the broader question: other training issues are likely caused by non-compliant input data. For instance, the model expects the same number of LR and HR images, paired by matching file names, and not all computer vision datasets follow this convention. Another issue may be the image format itself. I have prepared a notebook that attaches your model to the Urban 100 data. Finally, on a general note, expect difficulties when working with unmaintained software like this ISR package.
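As a quick first check for the pairing issue described above, a small helper can list image files that exist in one folder but not the other. This is a hypothetical helper, not part of ISR; it assumes that a pair shares exactly the same file name (including extension) and that the listed extensions cover your data, so adjust both to how your dataset is organized.

```python
import os

# Assumption: the image extensions used by the dataset; adjust as needed.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp'}


def find_unpaired_images(lr_dir, hr_dir, extensions=IMAGE_EXTENSIONS):
    """Hypothetical helper (not part of ISR).

    Returns (only_in_lr, only_in_hr): sorted lists of image file names that
    appear in one folder but have no identically named partner in the other.
    """
    def image_names(folder):
        # Keep only files whose extension looks like an image.
        return {
            name for name in os.listdir(folder)
            if os.path.splitext(name)[1].lower() in extensions
        }

    lr_names = image_names(lr_dir)
    hr_names = image_names(hr_dir)
    return sorted(lr_names - hr_names), sorted(hr_names - lr_names)
```

For example, `find_unpaired_images('C:/data/lr_train_150/', 'C:/data/hr_train_150/')` should return `([], [])` when every LR image has an HR partner with the same name.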

huangapple
  • This article was published on 2023-05-11 12:22:30.
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76224141.html