pytorch训练时如何跳过坏样本_在dataset或dataloader中实现过滤

Dataset中__getitem__里捕获异常并返回None

PyTorch的DataLoader默认遇到__getitem__抛异常就直接中断训练,但你其实可以让它“跳过”——关键不是让__getitem__崩溃,而是让它安静地返回None,再在DataLoader外层过滤掉None样本。

常见错误是直接在__getitem__里print然后continue,但continue在函数里没意义;也有人想raise SkipSampleError再用自定义collate捕获,太重,且容易漏处理。

在__getitem__开头加try,把加载/解码/预处理逻辑全包进去出错时return None(不要raise,也不要return []或0)确保你的__len__不因此虚高——如果坏样本比例高,建议先做一次预扫描,剔除已知损坏路径

示例:def __getitem__(self, idx): try: img = Image.open(self.files[idx]).convert(‘RGB’); return self.transform(img) except: return None

自定义collate_fn过滤None样本

DataLoader拿到一批__getitem__结果后,会调用collate_fn合并成batch。默认的default_collate遇到None直接报TypeError: batch must contain tensors, numbers, dicts or lists,所以必须自己写一个能跳过None的版本。

注意:这个collate_fn不能简单删掉None就完事——batch size会变,可能影响BN统计、梯度累积、甚至DistributedSampler的负载均衡。

先[x for x in batch if x is not None]过滤如果过滤后为空列表,返回None或抛一个轻量异常(如RuntimeError("empty batch")),让上层重试该batch不要用torch.utils.data.dataloader.default_collate直接套用,它不处理None;要手动对非None元素调用torch.stack或torch.cat

示例:def collate_fn(batch): batch = [x for x in batch if x is not None]; return torch.utils.data.dataloader.default_collate(batch) if batch else None

用IterDataPipe替代Dataset(PyTorch 2.0+推荐)

如果你用的是PyTorch ≥ 2.0,torchdata.datapipes.iter比传统Dataset更适合做样本级过滤——它天然支持链式操作、惰性求值,且filter行为明确,不会污染__len__。

老方案里__len__和实际可用样本数不一致是个隐患,尤其在分布式训练中,DistributedSampler按原始__len__切分,会导致某些rank拿到大量None,甚至卡死。

用FileLister + OpenFiles + LoadImage等DataPipe组合,每一步都可加.filter().filter(lambda x: x is not None)或更具体的条件,比如lambda x: x.size[0] > 0 and x.size[1] > 0最后用Batch和Collate组装,不用操心None穿透问题

示例:dp = FileLister(root).open_files().load_image().filter(lambda x: x is not None).batch(32).collate()

警惕“跳过”引发的隐性偏移

样本被跳过本身不是问题,问题是它悄悄改变了数据分布——比如某类图片格式集中损坏,你一跳,该类就系统性缺失;或者训练集里5%的样本总因内存不足被跳过,模型根本没见过这些模式。

更麻烦的是,验证/测试阶段如果沿用同一套逻辑,指标就不可信:你不知道是模型差,还是评估时又跳了关键样本。

训练时记录被跳过的样本ID或路径,定期抽检原因(是路径错?编码坏?尺寸超限?)验证集绝对不要跳——应该提前清洗,或用try/except + logging.warning报错但不跳,强制人工介入如果跳过率 > 2%,别急着修代码,先查数据源质量

最常被忽略的一点:DataLoader(num_workers>0)下,子进程里的异常默认静默丢弃,你根本看不到报错信息——务必在__getitem__里加logging.error,并设置logging.basicConfig(level=logging.ERROR)

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。