快捷方式

torch.utils.data

PyTorch 資料載入工具的核心是 torch.utils.data.DataLoader 類。它表示一個數據集上的 Python 可迭代物件,支援以下功能:

這些選項由 DataLoader 的建構函式引數配置,其簽名如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下部分詳細描述了這些選項的作用和用法。

資料集型別

DataLoader 建構函式最重要的引數是 dataset,它表示要從中載入資料的資料集物件。PyTorch 支援兩種不同型別的資料集:

對映風格資料集

對映風格資料集是實現了 __getitem__()__len__() 協議的資料集,表示從(可能非整數的)索引/鍵到資料樣本的對映。

例如,透過 dataset[idx] 訪問這樣的資料集時,可以從磁碟上的資料夾讀取第 idx 張影像及其對應的標籤。

詳情請參閱 Dataset

可迭代風格資料集

可迭代風格資料集是 IterableDataset 子類的例項,它實現了 __iter__() 協議,並表示資料樣本上的一個可迭代物件。此類資料集特別適用於隨機讀取開銷很大甚至不可能的情況,以及批次大小取決於獲取的資料的情況。

例如,呼叫 iter(dataset) 時,此類資料集可以返回一個數據流,資料來源可以是資料庫、遠端伺服器,甚至是即時生成的日誌。

詳情請參閱 IterableDataset

注意

IterableDataset多程序資料載入一起使用時,同一個資料集物件會在每個 worker 程序上覆制,因此必須對副本進行不同的配置以避免資料重複。有關如何實現此目的,請參閱 IterableDataset 文件。

資料載入順序和 Sampler

對於可迭代風格資料集,資料載入順序完全由使用者定義的可迭代物件控制。這使得實現分塊讀取和動態批次大小(例如,每次生成一個批次樣本)更加容易。

本節的其餘部分涉及對映風格資料集的情況。torch.utils.data.Sampler 類用於指定資料載入中使用的索引/鍵序列。它們表示資料集索引上的可迭代物件。例如,在隨機梯度下降 (SGD) 的常見情況下,Sampler 可以隨機排列索引列表並每次生成一個,或者為 mini-batch SGD 生成少量索引。

將根據 DataLoadershuffle 引數自動構建一個順序或隨機取樣器。或者,使用者可以使用 sampler 引數指定自定義的 Sampler 物件,該物件每次生成要獲取的下一個索引/鍵。

可以作為 batch_sampler 引數傳遞一個自定義的 Sampler,它每次生成一個批次索引列表。也可以透過 batch_sizedrop_last 引數啟用自動批次處理。有關更多詳細資訊,請參閱下一節

注意

samplerbatch_sampler 都不相容可迭代風格資料集,因為此類資料集沒有鍵或索引的概念。

載入批次和非批次資料

DataLoader 支援透過引數 batch_sizedrop_lastbatch_samplercollate_fn(它有一個預設函式)自動將單個獲取的資料樣本整理成批次。

自動批次處理(預設)

這是最常見的情況,對應於獲取一個 mini-batch 資料並將其整理成批次樣本,即包含一個維度作為批次維度(通常是第一個維度)的 Tensor。

batch_size(預設為 1)不是 None 時,資料載入器將生成批次樣本而不是單個樣本。batch_sizedrop_last 引數用於指定資料載入器如何獲取資料集鍵的批次。對於對映風格資料集,使用者也可以指定 batch_sampler,它每次生成一個鍵列表。

注意

batch_sizedrop_last 引數本質上用於從 sampler 構建一個 batch_sampler。對於對映風格資料集,sampler 由使用者提供或根據 shuffle 引數構建。對於可迭代風格資料集,sampler 是一個無限的虛擬取樣器。有關取樣器的更多詳細資訊,請參閱本節

注意

當使用多程序可迭代風格資料集獲取資料時,drop_last 引數會丟棄每個 worker 資料集副本的最後一個非完整批次。

在使用取樣器中的索引獲取樣本列表後,將使用作為 collate_fn 引數傳入的函式將樣本列表整理成批次。

在這種情況下,從對映風格資料集載入大致等同於:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

從可迭代風格資料集載入大致等同於:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以使用自定義的 collate_fn 來自定義整理過程,例如沿第一個維度以外的維度進行整理,填充各種長度的序列,或新增對自定義資料型別的支援。有關 collate_fn 的更多資訊,請參閱本節

停用自動批次處理

在某些情況下,使用者可能希望在資料集程式碼中手動處理批次處理,或者只加載單個樣本。例如,直接載入批次資料(例如,從資料庫批次讀取或讀取連續記憶體塊)可能更便宜,或者批次大小依賴於資料,或者程式設計用於處理單個樣本。在這些場景下,可能最好不要使用自動批次處理(即使用 collate_fn 整理樣本),而是讓資料載入器直接返回 dataset 物件的每個成員。

batch_sizebatch_sampler 均為 None 時(batch_sampler 的預設值已經是 None),自動批次處理被停用。從 dataset 獲取的每個樣本都將由作為 collate_fn 引數傳入的函式處理。

停用自動批次處理時,預設的 collate_fn 只是簡單地將 NumPy 陣列轉換為 PyTorch Tensor,而保留其他所有內容不變。

在這種情況下,從對映風格資料集載入大致等同於:

for index in sampler:
    yield collate_fn(dataset[index])

從可迭代風格資料集載入大致等同於:

for data in iter(dataset):
    yield collate_fn(data)

有關 collate_fn 的更多資訊,請參閱本節

使用 collate_fn

當自動批次處理啟用或停用時,collate_fn 的用法略有不同。

停用自動批次處理時collate_fn 會針對每個單獨的資料樣本呼叫,其輸出從資料載入器迭代器中生成。在這種情況下,預設的 collate_fn 只是簡單地將 NumPy 陣列轉換為 PyTorch Tensor。

啟用自動批次處理時collate_fn 每次都以資料樣本列表的形式呼叫。它應該將輸入的樣本整理成一個批次,以便從資料載入器迭代器中生成。本節的其餘部分描述了預設 collate_fn (default_collate()) 的行為。

例如,如果每個資料樣本包含一個 3 通道影像和一個整數類別標籤,即資料集的每個元素都返回一個元組 (image, class_index),則預設的 collate_fn 會將這樣的元組列表整理成一個包含批次影像張量和批次類別標籤張量的單個元組。特別是,預設的 collate_fn 具有以下特性:

  • 它始終將一個新的維度作為批次維度新增到最前面。

  • 它自動將 NumPy 陣列和 Python 數值轉換為 PyTorch Tensor。

  • 它保留資料結構,例如,如果每個樣本是一個字典,它會輸出一個具有相同鍵集但值為批次 Tensor(如果值不能轉換為 Tensor 則為列表)的字典。對於 listtuplenamedtuple 等也是如此。

使用者可以使用自定義的 collate_fn 來實現自定義批次處理,例如沿第一個維度以外的維度進行整理,填充各種長度的序列,或新增對自定義資料型別的支援。

如果 DataLoader 的輸出維度或型別與您的預期不同,您可能需要檢查您的 collate_fn

單程序和多程序資料載入

預設情況下,DataLoader 使用單程序資料載入。

在 Python 程序內部,全域性直譯器鎖 (GIL) 阻止了 Python 程式碼線上程之間真正的完全並行化。為了避免資料載入阻塞計算程式碼,PyTorch 提供了一個簡單的開關,只需將引數 num_workers 設定為正整數即可執行多程序資料載入。

單程序資料載入(預設)

在此模式下,資料獲取在初始化 DataLoader 的同一程序中完成。因此,資料載入可能會阻塞計算。但是,當用於程序間共享資料的資源(例如共享記憶體、檔案描述符)有限時,或者當整個資料集很小並且可以完全載入到記憶體中時,此模式可能更受歡迎。此外,單程序載入通常顯示更易讀的錯誤跟蹤,因此對於除錯很有用。

多程序資料載入

將引數 num_workers 設定為正整數將開啟多程序資料載入,並指定載入 worker 程序的數量。

警告

幾次迭代後,載入器 worker 程序將消耗與父程序相同的 CPU 記憶體量,用於訪問父程序中 worker 程序訪問的所有 Python 物件。如果 Dataset 包含大量資料(例如,您在 Dataset 構造時載入非常大的檔名列表)和/或您使用大量 worker,這可能會帶來問題(總記憶體使用量為 worker 數量 * 父程序大小)。最簡單的解決方法是用非引用計數表示(例如 Pandas、Numpy 或 PyArrow 物件)替換 Python 物件。請參閱 issue #13246,瞭解有關此問題發生原因和如何解決這些問題的示例程式碼的更多詳細資訊。

在此模式下,每次建立 DataLoader 的迭代器時(例如,當您呼叫 enumerate(dataloader) 時),會建立 num_workers 個 worker 程序。此時,datasetcollate_fnworker_init_fn 會傳遞給每個 worker,並在其中用於初始化和獲取資料。這意味著資料集訪問及其內部 IO、轉換(包括 collate_fn)都在 worker 程序中執行。

torch.utils.data.get_worker_info() 在 worker 程序中返回各種有用的資訊(包括 worker ID、資料集副本、初始種子等),在主程序中返回 None。使用者可以在資料集程式碼和/或 worker_init_fn 中使用此函式來單獨配置每個資料集副本,並確定程式碼是否在 worker 程序中執行。例如,這在資料集分片中特別有用。

對於對映風格資料集,主程序使用 sampler 生成索引並將其傳送給 worker。因此,任何洗牌隨機化都在主程序中完成,主程序透過分配要載入的索引來指導載入。

對於可迭代風格的資料集,由於每個工作程序都會獲得一個 dataset 物件的副本,簡單的多程序載入常常會導致資料重複。利用 torch.utils.data.get_worker_info() 和/或 worker_init_fn,使用者可以獨立配置每個副本。(參見 IterableDataset 文件瞭解如何實現)。出於類似原因,在多程序載入中,drop_last 引數會丟棄每個工作程序的可迭代風格資料集副本中的最後一個非完整批次。

當迭代結束時,或者當迭代器被垃圾回收時,工作程序就會關閉。

警告

通常不建議在多程序載入中返回 CUDA 張量,因為在多程序中使用 CUDA 和共享 CUDA 張量存在許多微妙之處(參見 多程序中的 CUDA)。相反,我們推薦使用自動記憶體鎖定(即設定 pin_memory=True),這能夠實現資料到支援 CUDA 的 GPU 的快速傳輸。

平臺特定的行為

由於工作程序依賴 Python 的 multiprocessing 模組,工作程序的啟動行為在 Windows 和 Unix 上是不同的。

  • 在 Unix 上,fork() 是預設的 multiprocessing 啟動方法。使用 fork(),子工作程序通常可以透過克隆的地址空間直接訪問 dataset 和 Python 引數函式。

  • 在 Windows 或 MacOS 上,spawn() 是預設的 multiprocessing 啟動方法。使用 spawn(),會啟動另一個直譯器,執行主指令碼,然後是接收透過 pickle 序列化傳遞的 datasetcollate_fn 和其他引數的內部工作函式。

這種單獨的序列化意味著在使用多程序資料載入時,您應該採取兩個步驟來確保與 Windows 相容

  • 將您主指令碼的大部分程式碼包裝在 if __name__ == '__main__': 塊中,以確保在啟動每個工作程序時它不會再次執行(很可能導致錯誤)。您可以將您的資料集和 DataLoader 例項建立邏輯放在這裡,因為它不需要在工作程序中重新執行。

  • 確保所有自定義的 collate_fnworker_init_fndataset 程式碼都宣告為頂級定義,在 __main__ 檢查之外。這確保它們在工作程序中可用。(這是必需的,因為函式僅作為引用被 pickle,而不是 bytecode)。

多程序資料載入中的隨機性

預設情況下,每個工作程序的 PyTorch 種子將被設定為 base_seed + worker_id,其中 base_seed 是主程序使用其 RNG 生成的長整型(因此必須消耗一個 RNG 狀態)或指定的 generator。然而,初始化工作程序時,其他庫的種子可能會重複,導致每個工作程序返回相同的隨機數。(參見 FAQ 中的本節)。

worker_init_fn 中,您可以透過 torch.utils.data.get_worker_info().seedtorch.initial_seed() 訪問為每個工作程序設定的 PyTorch 種子,並在載入資料之前使用它為其他庫設定種子。

記憶體鎖定

當資料來源自鎖定(頁鎖定)記憶體時,從主機到 GPU 的複製會快得多。有關何時以及如何通常使用鎖定記憶體的更多詳細資訊,請參見使用鎖定記憶體緩衝區

對於資料載入,將 pin_memory=True 傳遞給 DataLoader 會自動將獲取的資料張量放入鎖定記憶體中,從而實現資料到支援 CUDA 的 GPU 的快速傳輸。

預設的記憶體鎖定邏輯只識別張量以及包含張量的對映和可迭代物件。預設情況下,如果鎖定邏輯看到一個批次是自定義型別(如果您有一個返回自定義批次型別的 collate_fn,就會發生這種情況),或者如果您的批次的每個元素是自定義型別,鎖定邏輯將無法識別它們,並將直接返回該批次(或這些元素),而不鎖定記憶體。要為自定義批次或資料型別啟用記憶體鎖定,請在您的自定義型別上定義一個 pin_memory() 方法。

參見下面的示例。

示例

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[source][source]

資料載入器結合了資料集和取樣器,並提供了給定資料集上的可迭代物件。

DataLoader 支援單程序或多程序載入的 map 風格和可迭代風格資料集,支援定製載入順序以及可選的自動批處理(整理)和記憶體鎖定。

更多詳細資訊,請參見 torch.utils.data 文件頁面。

引數
  • dataset (Dataset) – 要從中載入資料的資料集。

  • batch_size (int, optional) – 每個批次要載入多少個樣本(預設值:1)。

  • shuffle (bool, optional) – 設定為 True 可在每個 epoch 重新打亂資料(預設值:False)。

  • sampler (Sampler or Iterable, optional) – 定義了從資料集中抽取樣本的策略。可以是實現了 __len__ 的任何 Iterable。如果指定,則不能指定 shuffle

  • batch_sampler (Sampler or Iterable, optional) – 類似於 sampler,但一次返回一批索引。與 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, optional) – 用於資料載入的子程序數。0 表示資料將在主程序中載入(預設值:0)。

  • collate_fn (Callable, optional) – 將樣本列表合併以形成張量的小批次。在從 map 風格資料集進行批處理載入時使用。

  • pin_memory (bool, optional) – 如果為 True,資料載入器將在返回張量之前將其複製到裝置/CUDA 鎖定記憶體中。如果您的資料元素是自定義型別,或者您的 collate_fn 返回的批次是自定義型別,請參見下面的示例。

  • drop_last (bool, optional) – 如果資料集大小不能被批次大小整除,設定為 True 可丟棄最後一個不完整的批次。如果為 False 且資料集大小不能被批次大小整除,則最後一個批次會較小(預設值:False)。

  • timeout (numeric, optional) – 如果為正數,則為從工作程序收集批次的超時時間。應始終為非負數(預設值:0)。

  • worker_init_fn (Callable, optional) – 如果不為 None,將在每個工作子程序上呼叫此函式,輸入為工作程序 ID(一個介於 [0, num_workers - 1] 的整數),在設定種子之後、載入資料之前(預設值:None)。

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – 如果為 None,將使用作業系統的預設多程序上下文(預設值:None)。

  • generator (torch.Generator, optional) – 如果不為 None,此 RNG 將由 RandomSampler 用於生成隨機索引,並由多程序用於為工作程序生成 base_seed(預設值:None)。

  • prefetch_factor (int, optional, keyword-only arg) – 每個工作程序預載入的批次數。2 意味著所有工作程序總共預載入 2 * num_workers 個批次(預設值取決於 num_workers 的設定值。如果 num_workers=0,預設值為 None。否則,如果 num_workers > 0,預設值為 2)。

  • persistent_workers (bool, optional) – 如果為 True,資料載入器在資料集被消費一次後不會關閉工作程序。這使得工作程序的 Dataset 例項保持活動狀態(預設值:False)。

  • pin_memory_device (str, optional) – 如果 pin_memoryTrue 時,進行 pin_memory 的裝置。如果未給出,當前加速器將作為預設值。不建議使用此引數,並且可能會被棄用。

  • in_order (bool, optional) – 如果為 False,資料載入器將不會強制要求批次以先進先出的順序返回。僅當 num_workers > 0 時適用(預設值:True)。

警告

如果使用 spawn 啟動方法,worker_init_fn 不能是無法 pickle 的物件,例如 lambda 函式。有關 PyTorch 中多程序的更多詳細資訊,請參見多程序最佳實踐

警告

len(dataloader) 的啟發式方法基於所用取樣器的長度。當 datasetIterableDataset 時,它會基於 len(dataset) / batch_size 返回一個估計值,並根據 drop_last 進行適當的四捨五入,而無論多程序載入配置如何。這代表了 PyTorch 所能做出的最佳猜測,因為 PyTorch 相信使用者的資料集程式碼能夠正確處理多程序載入以避免資料重複。

然而,如果分片導致多個工作程序的最後一個批次不完整,這個估計仍然可能不準確,因為 (1) 一個原本完整的批次可能會被分成多個不完整批次,並且 (2) 當設定了 drop_last 時,可能會丟棄超過一個批次的樣本。遺憾的是,PyTorch 通常無法檢測到這種情況。

有關這兩種型別資料集的更多詳細資訊,請參見資料集型別,以及IterableDataset 如何與多程序資料載入互動。

警告

在資料不平衡的情況下,將 in_order 設定為 False 可能會損害可重複性,並可能導致輸入訓練器的資料分佈不均勻。

class torch.utils.data.Dataset[source][source]

一個表示 Dataset 的抽象類。

所有表示從鍵到資料樣本的對映的資料集都應該繼承此類。所有子類都應該重寫 __getitem__() 方法,支援根據給定的鍵獲取資料樣本。子類還可以選擇重寫 __len__() 方法,許多 Sampler 實現和 DataLoader 的預設選項都期望此方法返回資料集的大小。子類還可以選擇實現 __getitems__() 方法,以加速批次樣本載入。此方法接受一個批次的樣本索引列表,並返回樣本列表。

注意

DataLoader 預設構造一個索引取樣器,該取樣器產生整數索引。為了使其與具有非整數索引/鍵的 map 風格資料集一起使用,必須提供自定義取樣器。

class torch.utils.data.IterableDataset[source][source]

一個可迭代的資料集。

所有表示資料樣本的可迭代物件的資料集都應該繼承此類。當資料來自流時,這種形式的資料集特別有用。

所有子類都應該重寫 __iter__() 方法,它將返回此資料集中樣本的迭代器。

當子類與 DataLoader 一起使用時,資料集中的每個項將從 DataLoader 迭代器中生成。當 num_workers > 0 時,每個工作程序將擁有資料集物件的不同副本,因此通常需要獨立配置每個副本以避免工作程序返回重複的資料。在工作程序中呼叫 get_worker_info() 時,會返回關於該工作程序的資訊。它可以用於資料集的 __iter__() 方法或 DataLoaderworker_init_fn 選項中,以修改每個副本的行為。

示例 1:在 __iter__() 中將工作負載分攤到所有工作程序

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

示例 2:使用 worker_init_fn 將工作負載分攤到所有工作程序

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[source][source]

封裝張量的資料集。

每個樣本將透過沿第一個維度索引張量來檢索。

引數

*tensors (Tensor) – 具有相同第一維大小的張量。

class torch.utils.data.StackDataset(*args, **kwargs)[source][source]

將多個數據集堆疊而成的資料集。

當輸入資料複雜且由多個數據集組成時,此類有助於將其不同部分組合起來。

示例

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
引數
  • *args (Dataset) – 作為元組返回的用於堆疊的資料集。

  • **kwargs (Dataset) – 作為字典返回的用於堆疊的資料集。

class torch.utils.data.ConcatDataset(datasets)[source][source]

將多個數據集連線而成的資料集。

此類有助於組合不同的現有資料集。

引數

datasets (sequence) – 要連線的資料集列表。

class torch.utils.data.ChainDataset(datasets)[source][source]

用於連線多個 IterableDataset 的資料集。

此類有助於組合不同的現有資料集流。連線操作是即時完成的,因此使用此類連線大型資料集將非常高效。

引數

資料集 (datasets) (IterableDataset 的 iterable) – 需要鏈式組合的資料集

class torch.utils.data.Subset(dataset, indices)[source][source]

資料集在指定索引處的子集。

引數
  • dataset (Dataset) – 整個資料集

  • indices (sequence) – 為子集選擇的整個集合中的索引

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[source][source]

處理批處理中集合型別元素的通用整理(collate)函式。

此函式還開放了函式登錄檔,用於處理特定的元素型別。default_collate_fn_map 為張量、NumPy 陣列、數字和字串提供了預設的整理函式。

引數
  • batch – 需要整理的單個批次

  • collate_fn_map (Optional[dict[Union[type, tuple[type, ...]], Callable]]) – 可選字典,將元素型別對映到相應的整理(collate)函式。如果元素型別不在該字典中,此函式將按照插入順序遍歷字典的每個鍵,如果元素型別是該鍵的子類,則呼叫相應的整理函式。

示例

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

注意

每個整理函式都需要一個用於批處理的位置引數和一個用於整理函式字典的關鍵字引數 collate_fn_map

torch.utils.data.default_collate(batch)[source][source]

接收一批資料,並將批處理中的元素放入一個具有額外外部維度(批大小)的張量中。

確切的輸出型別可以是 torch.Tensortorch.TensorSequencetorch.Tensor 的 Collection,或者保持不變,具體取決於輸入型別。當 DataLoader 中定義了 batch_sizebatch_sampler 時,這被用作整理的預設函式。

以下是通用輸入型別(基於批處理中元素的型別)到輸出型別的對映

  • torch.Tensor -> torch.Tensor (增加了一個外部維度:批大小)

  • NumPy Arrays -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str (不變)

  • bytes -> bytes (不變)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

引數

batch – 需要整理的單個批次

示例

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update(CustomType, collate_customtype_fn)
>>> default_collate(batch)  # Handle `CustomType` automatically
torch.utils.data.default_convert(data)[source][source]

將每個 NumPy 陣列元素轉換為 torch.Tensor

如果輸入是 SequenceCollectionMapping,它會嘗試將內部的每個元素轉換為 torch.Tensor。如果輸入不是 NumPy 陣列,則保持不變。當 DataLoader 中未定義 batch_samplerbatch_size 時,這被用作整理的預設函式。

通用的輸入型別到輸出型別的對映與 default_collate() 類似。有關更多詳細資訊,請參閱那裡的描述。

引數

data – 需要轉換的單個數據點

示例

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[source][source]

返回有關當前 DataLoader 迭代器工作程序的資訊。

在 worker 中呼叫時,此函式返回一個保證具有以下屬性的物件:

  • id: 當前 worker ID。

  • num_workers: worker 總數。

  • seed: 為當前 worker 設定的隨機種子。此值由主程序 RNG 和 worker ID 確定。有關更多詳細資訊,請參閱 DataLoader 的文件。

  • dataset: 程序中資料集物件的副本。請注意,這與主程序中的物件是不同的物件。

在主程序中呼叫時,此函式返回 None

注意

在傳遞給 DataLoaderworker_init_fn 中使用時,此方法可用於以不同方式設定每個 worker 程序,例如,使用 worker_id 配置 dataset 物件,使其僅讀取分片資料集的特定部分,或使用 seed 為資料集程式碼中使用的其他庫設定種子。

返回型別

Optional[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source][source]

將資料集隨機拆分為給定長度的互不重疊的新資料集。

如果給定一個總和為 1 的分數列表,則長度將根據提供的每個分數自動計算為 floor(frac * len(dataset))。

計算長度後,如果存在任何餘數,將以迴圈方式向長度分配 1 個計數,直到沒有餘數。

可以選擇固定生成器以獲得可復現的結果,例如:

示例

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
引數
  • dataset (Dataset) – 需要拆分的資料集

  • lengths (sequence) – 需要生成的拆分的長度或分數

  • generator (Generator) – 用於隨機排列的生成器。

返回型別

list[torch.utils.data.dataset.Subset[~_T]]

class torch.utils.data.Sampler(data_source=None)[source][source]

所有 Sampler 的基類。

每個 Sampler 子類都必須提供 __iter__() 方法,提供一種迭代資料集元素索引或索引列表(批處理)的方式,並且可以提供 __len__() 方法,返回返回的迭代器的長度。

引數

data_source (Dataset) – 此引數未使用,並將在 2.2.0 中移除。您仍可能擁有使用它的自定義實現。

示例

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

注意

DataLoader 並非嚴格要求 __len__() 方法,但在涉及 DataLoader 長度的任何計算中都期望此方法。

class torch.utils.data.SequentialSampler(data_source)[source][source]

順序取樣元素,始終按相同順序。

引數

data_source (Dataset) – 從中取樣的資料集

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source][source]

隨機取樣元素。如果不允許替換,則從打亂的資料集中取樣。

如果允許替換,使用者可以指定 num_samples 進行抽取。

引數
  • data_source (Dataset) – 從中取樣的資料集

  • replacement (bool) – 如果為 True,則按需有放回取樣,預設值為 False

  • num_samples (int) – 抽樣的數量,預設值為 `len(dataset)`。

  • generator (Generator) – 取樣時使用的生成器。

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source][source]

從給定的索引列表中隨機取樣元素,不允許替換。

引數
  • indices (sequence) – 索引序列

  • generator (Generator) – 取樣時使用的生成器。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source][source]

以給定的機率(權重)從 [0,..,len(weights)-1] 中取樣元素。

引數
  • weights (sequence) – 權重序列,不必總和為一

  • num_samples (int) – 抽樣的數量

  • replacement (bool) – 如果為 True,則有放回取樣。否則,則無放回取樣,這意味著當某行的樣本索引被抽取後,該行不能再次抽取該索引。

  • generator (Generator) – 取樣時使用的生成器。

示例

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source][source]

包裝另一個 sampler 以生成 mini-batch 索引。

引數
  • sampler (Sampler or Iterable) – 基礎 sampler。可以是任何可迭代物件

  • batch_size (int) – mini-batch 的大小。

  • drop_last (bool) – 如果為 True,如果最後一個批次的大小小於 batch_size,則 sampler 將丟棄最後一個批次

示例

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source][source]

將資料載入限制到資料集子集的 sampler。

它在使用 torch.nn.parallel.DistributedDataParallel 時尤其有用。在這種情況下,每個程序可以將 DistributedSampler 例項作為 DataLoader sampler 傳遞,並載入原始資料集中獨屬於它的子集。

注意

假設資料集大小恆定,並且其任何例項始終以相同順序返回相同元素。

引數
  • dataset (Dataset) – 用於取樣的訓練集。

  • num_replicas (int, optional) – 參與分散式訓練的程序數。預設情況下,從當前的分散式組中檢索 world_size

  • rank (int, optional) – 當前程序在 num_replicas 中的 rank。預設情況下,從當前的分散式組中檢索 rank

  • shuffle (bool, optional) – 如果為 True(預設),sampler 將打亂索引。

  • seed (int, optional) – 如果 shuffle=True,用於打亂 sampler 的隨機種子。此數字在分散式組中的所有程序之間應相同。預設值:0

  • drop_last (bool, optional) – 如果為 True,則 sampler 將丟棄資料尾部,使其可以均勻地分配給副本數。如果為 False,sampler 將新增額外索引,使資料可以均勻地分配給副本。預設值:False

警告

在分散式模式下,在建立 DataLoader 迭代器之前,在每個 epoch 開始時呼叫 set_epoch() 方法對於在多個 epoch 中使打亂正常工作是必要的。否則,將始終使用相同的順序。

示例

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

文件

訪問 PyTorch 的綜合開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源