tf.keras.utils.Sequence ignores the last batch when it is smaller than the batch size

Question

I use keras.utils.Sequence to generate data for my TensorFlow model. However, I realised that when iterating over the generator, the last batch, which is smaller than batch_size, is not returned, so some data is ignored.

Here is some code that illustrates my problem:
TensorFlow version: 2.10.0
Python version: 3.9.5

    import numpy as np
    import tensorflow as tf


    class TEST_DATA_GENERATOR(tf.keras.utils.Sequence):
        def __init__(self):
            self.samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            self.batch_size = 4
            self.shuffle = False
            self.indices = range(0, len(self.samples))
            assert self.batch_size <= len(self.indices), "batch size must not exceed the number of samples"
            self.on_epoch_end()  # build (and optionally shuffle) self.index

        def __len__(self):
            # Number of batches per epoch
            return len(self.indices) // self.batch_size

        def __getitem__(self, index):
            # Positions of the samples that belong to batch `index`
            index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
            batch = [self.indices[k] for k in index]

            X, y = self.__get_data(batch)
            return X, y

        def on_epoch_end(self):
            # Called once at creation (and at every epoch end); stores all indices as an array in self.index
            self.index = np.arange(len(self.indices))
            if self.shuffle:
                np.random.shuffle(self.index)

        def __get_data(self, batch):
            # batch = list of sample indices
            X = []
            y = []
            for sample_idx in batch:
                X.append(self.samples[sample_idx])
                y.append("classlabel")  # placeholder label

            X = np.asarray(X)
            y = np.asarray(y)
            return X, y


    testgen = TEST_DATA_GENERATOR()

    # Check shapes of loaded data by fetching each batch directly
    for i in range(3):
        x, y = testgen[i]
        print(x.shape)

    print("----")
    # Check shapes of loaded data by iterating over the Sequence
    for x, y in testgen:
        print(x.shape)

The output is:

(4,)
(4,)
(2,)
----
(4,)
(4,)

So when I fetch the batches one by one, all of them are returned, including the last one with only 2 elements. But when iterating, the last batch is discarded, which I don't want. Any solutions?
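A minimal, TensorFlow-free sketch of why the two approaches disagree. FloorLenSequence is a hypothetical stand-in for the generator above, and its __iter__ is a simplified assumption modelled on how tf.keras.utils.Sequence iterates in TF 2.x (requesting only indices 0 to len(self) - 1); it is not the library's actual code. Direct indexing still works for index 2 because slicing past the end of a NumPy array just returns a shorter array.

```python
import numpy as np


class FloorLenSequence:
    """Hypothetical stand-in for the question's generator (no TensorFlow needed)."""

    def __init__(self, samples, batch_size):
        self.samples = np.asarray(samples)
        self.batch_size = batch_size

    def __len__(self):
        # Floor division: 10 // 4 == 2, so only two batches are advertised.
        return len(self.samples) // self.batch_size

    def __getitem__(self, index):
        # Slicing past the end of a NumPy array returns a shorter array,
        # so index 2 still yields the partial batch [9, 10].
        start = index * self.batch_size
        return self.samples[start:start + self.batch_size]

    def __iter__(self):
        # Simplified model of Sequence iteration (assumption):
        # only indices 0 .. len(self) - 1 are ever requested.
        for i in range(len(self)):
            yield self[i]


seq = FloorLenSequence(list(range(1, 11)), batch_size=4)
direct = [seq[i].shape for i in range(3)]
iterated = [batch.shape for batch in seq]
print(direct)    # [(4,), (4,), (2,)]
print(iterated)  # [(4,), (4,)]
```

The partial batch exists and is reachable by index; it is simply never requested because the loop bound comes from __len__.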

Answer 1

Score: 3

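The problem is that your __len__ method is implemented incorrectly. Keras uses __len__ as the number of batches, and iterating over a Sequence only requests indices 0 to len(self) - 1, so the partial batch at index 2 is never fetched.

You are using integer division, which rounds down; instead you should use math.ceil, which rounds up. Just compare these two outputs:

>>> import math
>>> 10 // 4
2
>>> math.ceil(10 / 4)
3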
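A sketch of the fix applied to __len__. CeilLenSequence is a hypothetical, TensorFlow-free stand-in so the snippet runs on its own; in real code you would keep subclassing tf.keras.utils.Sequence and change only the __len__ body. The __iter__ method is the same simplified assumption about index-driven iteration (0 to len(self) - 1), not Keras source code.

```python
import math

import numpy as np


class CeilLenSequence:
    """Hypothetical stand-in; in practice subclass tf.keras.utils.Sequence."""

    def __init__(self, samples, batch_size):
        self.samples = np.asarray(samples)
        self.batch_size = batch_size

    def __len__(self):
        # Round up so the final, smaller batch is counted: ceil(10 / 4) == 3.
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, index):
        start = index * self.batch_size
        return self.samples[start:start + self.batch_size]

    def __iter__(self):
        # Simplified index-driven loop (assumption): 0 .. len(self) - 1.
        for i in range(len(self)):
            yield self[i]


seq = CeilLenSequence(list(range(1, 11)), batch_size=4)
shapes = [batch.shape for batch in seq]
print(shapes)  # [(4,), (4,), (2,)]
```

When the sample count divides evenly, rounding up adds nothing (ceil(8 / 4) == 2), so no empty batch is produced. An integer-only equivalent is (len(self.samples) + self.batch_size - 1) // self.batch_size, which avoids floating-point division.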

  • Posted on February 23, 2023, 22:09:32
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75545931.html