ImageDataGenerator 关于输入长度的问题

huangapple go评论64阅读模式
英文:

ImageDataGenerator regarding length of inputs

问题

当我将这两个列表传递给.flow()函数时,出现了以下错误信息:

ValueError: `x`(图像张量`y`(标签应具有相同的长度发现x.shape = (32, 32, 3)y.shape = (2, 4)

但是,当我打印x.shapey.shape时,它告诉我它们的长度相同。以下是我的代码:

dategen = ImageDataGenerator(
    rescale = 1./255
)

labels = list(df["label"])
filepaths = list(df["photo"])

features = []

for filepath in filepaths:
    img = Image.open(filepath)
    img = img.resize((32,32))
    features.append(np.array(img))

features = np.array(features)
labels = np.array(labels)

print(features.shape) #打印结果为 (2, 32, 32, 3)
print(labels.shape)   #打印结果为 (2, 4)

train_gen = datagen.flow(
    x = features,
    y = labels,
    batch_size = 1
)

我认为也许特征数组的第一个维度已经被修剪了,所以我尝试扩展特征数组的第一个维度:

features = np.expand_dims(features, axis = 0)
print(features.shape) #打印结果为 (1, 2, 32, 32, 3)

但然后我得到了这个错误:

ValueError: ... 发现x.shape = (1, 2, 32, 32, 3)...

发生了什么?

英文:

I have a list of features containing 3D numpy arrays representing a RGB image as well as a list of labels which are 1D numpy arrays. I want to perform image augmentation using the ImageDataGenerator.flow(). When passing the two lists into .flow(), I get:

ValueError: `x` (images tensor) and `y` (labels) should have the same length. Found: x.shape = (32, 32, 3), y.shape = (2, 4)

However, when I print x.shape and y.shape, it tells me that they are the same length. Here is my code:

dategen = ImageDataGenerator(
    rescale = 1./255
)


labels = list(df["label"])
filepaths = list(df["photo"])

features = []

for filepath in filepaths:
    img = Image.open(filepath)
    img = img.resize((32,32))
    features.append(np.array(img))

features = np.array(features)
labels = np.array(labels)

print(features.shape) #Prints (2, 32, 32, 3)
print(labels.shape)   #Prints (2, 4)

train_gen = datagen.flow(
    x = features,
    y = labels,
    batch_size = 1
)

I thought that perhaps the first dimension of the features array has been trimmed so I tried expanding the first dimension of the features array

features = np.expand_dims(features, axis = 0)
print(features.shape) #Prints (1, 2, 32, 32, 3)

But then I get this error:

ValueError: ... Found: x.shape = (1, 2, 32, 32, 3), ...

What is going on?

答案1

得分: 0

我通过添加这行代码解决了它。形状仍然相同,但在运行这行代码后它不知何故就起作用了。

labels = np.reshape(labels, (features.shape[0], -1))
英文:

I solved it by adding this line of code. The shape is still the same but somehow it works after running this.

labels = np.reshape(labels, (features.shape[0], -1))

huangapple
  • 本文由 发表于 2023年6月15日 12:19:57
  • 转载请务必保留本文链接:https://go.coder-hub.com/76479078.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定