pytorch set_epoch()方法

在分布式模式下,需要在每个 epoch 开始时调用 set_epoch() 方法,然后再创建 DataLoader 迭代器,以使 shuffle 操作能够在多个 epoch 中正常工作。 否则,dataloader 迭代器产生的数据将始终使用相同的顺序。

1
2
3
4
5
6
7
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None),
sampler=sampler)
for epoch in range(start_epoch, n_epochs):
if is_distributed:
sampler.set_epoch(epoch)
train(loader)

参考

https://docs.pytorch.org/docs/stable/data.html

https://zhuanlan.zhihu.com/p/97115875


pytorch set_epoch()方法
http://yojayc.github.io/2022/03/21/pytorch-set-epoch-方法/
作者
Truman
发布于
2022年3月21日
更新于
2026年4月5日
许可协议