I would like to finetune the blip model on ROCO data set for image captioning of chest x-rays

I want to fine-tune the BLIP model on the ROCO dataset for image captioning of chest X-ray images, but I am getting an error about integer indexing.

Can anyone please help me understand the cause of the error and how to fix it?

This is the code:

def read_data(filepath, csv_path, n_samples):
    df = pd.read_csv(csv_path)
    images = []
    capts = []
    for idx in range(len(df)):
        if 'hest x-ray' in df['caption'][idx] or 'hest X-ray' in df['caption'][idx]:
            if len(images) > n_samples:
                break
            else:
                images.append(Image.open(os.path.join(filepath, df['name'][idx])).convert('L'))
                capts.append(df['caption'][idx])
    return images, capts

def get_data():
    imgtrpath = 'all_data/train/radiology/images'
    trcsvpath = 'all_data/train/radiology/traindata.csv'
    imgtspath = 'all_data/test/radiology/images'
    tscsvpath = 'all_data/test/radiology/testdata.csv'
    imgvalpath = 'all_data/validation/radiology/images'
    valcsvpath = 'all_data/validation/radiology/valdata.csv'

    print('Extracting Training Data')
    trainimgs, traincapts = read_data(imgtrpath, trcsvpath, 1800)
    print('Extracting Testing Data')
    testimgs, testcapts = read_data(imgtrpath, trcsvpath, 100)
    print('Extracting Validation Data')
    valimgs, valcapts = read_data(imgtrpath, trcsvpath, 100)
    return trainimgs, traincapts, testimgs, testcapts, valimgs, valcapts

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainimgs, traincapts, testimgs, testcapts, valimgs, valcapts = get_data()
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
metric = evaluate.load("accuracy")
traindata = processor(text=traincapts, images=trainimgs, return_tensors="pt", padding=True, truncation=True)
evaldata = processor(text=testcapts, images=testimgs, return_tensors="pt", padding=True, truncation=True)
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=traindata,
    eval_dataset=evaldata,
    compute_metrics=compute_metrics
)
trainer.train()

The code is meant to fine-tune the BLIP model on the chest X-ray images of the ROCO dataset for image captioning.
But when I run it, I get this error:

  File "C:\Users\omair\anaconda3\envs\torch\lib\site-packages\transformers\feature_extraction_utils.py", line 86, in __getitem__
raise KeyError("Indexing with integers is not available when using Python based feature extractors")
KeyError: 'Indexing with integers is not available when using Python based feature extractors'

Answer 1 (score: 0)

There are two issues here:

  1. You're not providing labels during training; your ...capts are passed as the model's "Question". There is an example of how to do that in the link below.
  2. Fine-tuning HF's BlipForConditionalGeneration is not supported at the moment; see https://discuss.huggingface.co/t/finetune-blip-on-customer-dataset-20893/28446, where they just fixed BlipForQuestionAnswering. If you create a dataset based on that link, you will also hit the error ValueError: Expected input batch_size (0) to match target batch_size (511), which can be solved by reproducing the changes made to BlipForQuestionAnswering in BlipForConditionalGeneration.
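Both points, and the KeyError in the question, come down to the same thing: Trainer indexes train_dataset with integers, but the BatchEncoding returned by the processor does not support integer indexing, and no labels key is ever provided. A minimal sketch of a wrapper that fixes both, assuming the processor output is a dict-like of batched tensors (the class name CaptionDataset is hypothetical, not from the post or the linked thread):

```python
import torch
from torch.utils.data import Dataset

class CaptionDataset(Dataset):
    """Wraps the processor's BatchEncoding so Trainer can index it by integer."""

    def __init__(self, encodings):
        # encodings: the dict-like output of processor(text=..., images=...,
        # return_tensors="pt", padding=True)
        self.encodings = encodings

    def __len__(self):
        return self.encodings["input_ids"].shape[0]

    def __getitem__(self, idx):
        # Slice one sample out of each batched tensor; Trainer collates these.
        item = {key: val[idx] for key, val in self.encodings.items()}
        # For captioning, the caption token ids double as the training labels.
        item["labels"] = item["input_ids"].clone()
        return item
```

With this, `train_dataset=CaptionDataset(traindata)` and `eval_dataset=CaptionDataset(evaldata)` give the Trainer integer-indexable datasets that carry a labels field; whether the loss then works end-to-end still depends on the BlipForConditionalGeneration fix discussed in point 2.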

huangapple
  • Published 2023-02-13 23:26:37
  • Please keep this link when reposting: https://go.coder-hub.com/75437911.html