捷徑

torch.utils.data

PyTorch 資料載入工具的核心是 torch.utils.data.DataLoader 類別。它表示資料集上的 Python 可迭代物件,支援

這些選項由 DataLoader 的建構函式參數配置,其簽名為

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下各節詳細說明這些選項的效果和用法。

資料集類型

DataLoader 建構函式最重要的參數是 dataset,它表示要從中載入資料的資料集物件。PyTorch 支援兩種不同類型的資料集

映射風格的資料集

映射風格的資料集是實作 __getitem__()__len__() 協定的資料集,並表示從(可能是非整數)索引/鍵到資料範例的映射。

例如,當使用 dataset[idx] 存取此類資料集時,可以從磁碟上的資料夾中讀取第 idx 個影像及其對應的標籤。

如需詳細資訊,請參閱 Dataset

可迭代風格的資料集

可迭代風格的資料集是實作 __iter__() 協定的 IterableDataset 子類別的實例,並表示資料範例上的可迭代物件。這種類型的資料集特別適用於隨機讀取成本高昂甚至不可能的情況,以及批次大小取決於提取資料的情況。

例如,當呼叫 iter(dataset) 時,此類資料集可以返回從資料庫、遠端伺服器甚至即時產生的日誌中讀取的資料流。

如需詳細資訊,請參閱 IterableDataset

備註

當使用 IterableDataset 進行 多程序資料載入 時。相同的資料集物件會複製到每個工作程序上,因此必須以不同的方式配置副本以避免資料重複。如需如何實現此目的,請參閱 IterableDataset 文件。

資料載入順序和 Sampler

對於 可迭代風格的資料集,資料載入順序完全由使用者定義的可迭代物件控制。這允許更容易地實作區塊讀取和動態批次大小(例如,每次產生一個批次處理的範例)。

本節的其餘部分涉及 映射風格的資料集 的情況。torch.utils.data.Sampler 類別用於指定資料載入中使用的索引/鍵序列。它們表示資料集索引上的可迭代物件。例如,在使用隨機梯度下降 (SGD) 的常見情況下,Sampler 可以隨機排列索引列表並一次產生一個索引,或者為迷你批次 SGD 產生少量索引。

將根據 DataLoadershuffle 參數自動建構順序或隨機播放的取樣器。或者,使用者可以使用 sampler 參數指定自訂 Sampler 物件,該物件每次產生要提取的下一個索引/鍵。

可以將自訂的 Sampler(一次產生一批批次索引清單)作為 batch_sampler 參數傳遞。也可以透過 batch_sizedrop_last 參數啟用自動批次處理。如需更多詳細資訊,請參閱下一節

備註

samplerbatch_sampler 皆與可迭代樣式資料集不相容,因為此類資料集沒有索引鍵或索引的概念。

載入批次和非批次資料

DataLoader 支援透過參數 batch_sizedrop_lastbatch_samplercollate_fn(具有預設函數)自動將個別擷取的資料範例整理到批次中。

自動批次處理(預設)

這是最常見的情況,對應於擷取一個小批次資料並將其整理到批次範例中,亦即,包含一個維度為批次維度(通常是第一個)的張量。

batch_size(預設為 1)不是 None 時,資料載入器會產生批次範例,而不是個別範例。batch_sizedrop_last 參數用於指定資料載入器如何取得資料集索引鍵的批次。對於映射樣式資料集,使用者可以選擇指定 batch_sampler,它一次產生一個索引鍵清單。

備註

batch_sizedrop_last 參數基本上用於從 sampler 建構 batch_sampler。對於映射樣式資料集,sampler 由使用者提供或根據 shuffle 參數建構。對於可迭代樣式資料集,sampler 是一個虛擬的無限 sampler。如需有關取樣器的更多詳細資訊,請參閱本節

備註

當使用多程序可迭代樣式資料集擷取時,drop_last 參數會刪除每個工作程序的資料集複本中最後一個非完整批次。

使用取樣器中的索引擷取範例清單後,會使用作為 collate_fn 參數傳遞的函數將範例清單整理到批次中。

在這種情況下,從映射樣式資料集載入大致相當於

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

而從可迭代樣式資料集載入大致相當於

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以使用自訂的 collate_fn來自訂整理,例如,將序列資料填充到批次的最大長度。如需有關 collate_fn 的更多資訊,請參閱本節

停用自動批次處理

在某些情況下,使用者可能希望在資料集程式碼中手動處理批次處理,或者只是載入個別範例。例如,直接載入批次資料(例如,從資料庫大量讀取或讀取連續的記憶體區塊)可能更便宜,或者批次大小取決於資料,或者程式設計為處理個別範例。在這些情況下,最好不要使用自動批次處理(其中使用 collate_fn 來整理範例),而是讓資料載入器直接返回 dataset 物件的每個成員。

batch_sizebatch_sampler 都是 None 時(batch_sampler 的預設值已經是 None),就會停用自動批次處理。從 dataset 取得的每個範例都會使用作為 collate_fn 參數傳遞的函數進行處理。

**當自動批次處理被停用時**,預設的 collate_fn 只會將 NumPy 陣列轉換為 PyTorch 張量,並保持其他所有內容不變。

在這種情況下,從映射樣式資料集載入大致相當於

for index in sampler:
    yield collate_fn(dataset[index])

而從可迭代樣式資料集載入大致相當於

for data in iter(dataset):
    yield collate_fn(data)

如需有關 collate_fn 的更多資訊,請參閱本節

使用 collate_fn

當啟用或停用自動批次處理時,collate_fn 的用途略有不同。

**當自動批次處理被停用時**,會使用每個個別資料範例呼叫 collate_fn,並從資料載入器迭代器產生輸出。在這種情況下,預設的 collate_fn 只會將 PyTorch 張量中的 NumPy 陣列轉換。

**當啟用自動批次處理時**,每次都會使用資料範例清單呼叫 collate_fn。預計會將輸入範例整理到批次中,以便從資料載入器迭代器產生。本節的其餘部分描述了預設 collate_fn (default_collate()) 的行為。

例如,如果每個資料範例都包含一個 3 通道影像和一個整數類別標籤,亦即,資料集的每個元素都返回一個元組 (image, class_index),則預設的 collate_fn 會將此類元組的清單整理到一個批次影像張量和一個批次類別標籤張量的單一元組中。特別是,預設的 collate_fn 具有以下屬性

  • 它始終會將一個新維度作為批次維度。

  • 它會自動將 NumPy 陣列和 Python 數值轉換為 PyTorch 張量。

  • 它會保留資料結構,例如,如果每個範例都是一個字典,它會輸出一個具有相同索引鍵集的字典,但值為批次張量(如果值無法轉換為張量,則為清單)。listtuplenamedtuple 等也是如此。

使用者可以使用自訂的 collate_fn 來實現自訂批次處理,例如,沿著第一個維度以外的維度進行整理、填充各種長度的序列或新增對自訂資料類型的支援。

如果您遇到 DataLoader 的輸出具有與您的預期不同的維度或類型的狀況,您可能需要檢查您的 collate_fn

單程序和多程序資料載入

DataLoader 預設使用單程序資料載入。

在 Python 程序中,全域直譯器鎖定(GIL)會阻止 Python 程式碼在執行緒間真正完全平行化。為了避免使用資料載入來阻塞計算程式碼,PyTorch 提供了一個簡單的切換方式,只需將參數 num_workers 設定為正整數,即可執行多程序資料載入。

單程序資料載入(預設)

在此模式下,資料擷取是在初始化 DataLoader 的相同程序中完成的。因此,資料載入可能會阻塞計算。但是,當用於在程序之間共享資料的資源(例如,共享記憶體、檔案描述符)有限時,或者當整個資料集很小並且可以完全載入到記憶體中時,可能會優先使用此模式。此外,單程序載入通常會顯示更易讀的錯誤追蹤,因此有助於偵錯。

多程序資料載入

將參數 num_workers 設定為正整數將使用指定的載入器工作程序數量開啟多程序資料載入。

警告

經過多次迭代之後,對於從 worker 進程存取的父進程中的所有 Python 物件,loader worker 進程將會消耗與父進程相同數量的 CPU 記憶體。如果資料集包含大量資料(例如,您在資料集建構時載入了一個非常大的檔名清單)和/或您正在使用大量 worker(整體記憶體使用量為 worker 數量 * 父進程大小),這可能會產生問題。最簡單的解決方法是用非引用計數表示法(例如 Pandas、Numpy 或 PyArrow 物件)替換 Python 物件。請查看問題 #13246,以取得有關發生此問題的原因以及如何解決這些問題的範例程式碼的詳細資訊。

在此模式下,每次建立 DataLoader 的迭代器時(例如,當您呼叫 enumerate(dataloader) 時),就會建立 num_workers 個 worker 進程。此時,datasetcollate_fnworker_init_fn 會被傳遞到每個 worker,它們會在 worker 中用於初始化和提取資料。這表示資料集存取及其內部 IO、轉換(包括 collate_fn)會在 worker 進程中執行。

torch.utils.data.get_worker_info() 會在 worker 進程中傳回各種有用的資訊(包括 worker ID、資料集副本、初始種子等),並在主進程中傳回 None。使用者可以在資料集程式碼和/或 worker_init_fn 中使用此函數來個別設定每個資料集副本,並判斷程式碼是否在 worker 進程中執行。例如,這在對資料集進行分片時特別有用。

對於映射樣式資料集,主進程使用 sampler 產生索引並將其發送到 worker。因此,任何隨機排序操作都是在主進程中完成的,主進程會透過分配索引來引導載入。

對於可迭代樣式資料集,由於每個 worker 進程都會取得 dataset 物件的副本,因此單純的多進程載入通常會導致資料重複。使用 torch.utils.data.get_worker_info() 和/或 worker_init_fn,使用者可以獨立設定每個副本。(請參閱 IterableDataset 文件,瞭解如何實現這一點。)出於類似的原因,在多進程載入中,drop_last 參數會捨棄每個 worker 的可迭代樣式資料集副本的最後一個非完整批次。

一旦迭代結束,或者迭代器被垃圾回收,worker 就會被關閉。

警告

通常不建議在多進程載入中傳回 CUDA 張量,因為在多進程中使用 CUDA 和共享 CUDA 張量存在許多微妙之處(請參閱 多進程中的 CUDA)。相反,我們建議使用 自動記憶體固定(即設定 pin_memory=True),這可以實現快速將資料傳輸到支援 CUDA 的 GPU。

平台特定行為

由於 worker 依賴於 Python multiprocessing,因此 Windows 上的 worker 啟動行為與 Unix 不同。

  • 在 Unix 上,fork() 是預設的 multiprocessing 啟動方法。使用 fork(),子 worker 通常可以透過複製的位址空間直接存取 dataset 和 Python 參數函數。

  • 在 Windows 或 MacOS 上,spawn() 是預設的 multiprocessing 啟動方法。使用 spawn(),會啟動另一個直譯器來執行您的主指令碼,然後執行內部 worker 函數,該函數會透過 pickle 序列化接收 datasetcollate_fn 和其他參數。

這種單獨的序列化意味著您應該採取兩個步驟,以確保在使用多進程資料載入時與 Windows 相容

  • 將大部分主指令碼的程式碼包裝在 if __name__ == '__main__': 區塊中,以確保在啟動每個 worker 進程時不會再次執行它(很可能會產生錯誤)。您可以將資料集和 DataLoader 實例建立邏輯放在這裡,因為它不需要在 worker 中重新執行。

  • 確保任何自訂的 collate_fnworker_init_fndataset 程式碼都宣告為頂級定義,位於 __main__ 檢查之外。這確保它們在 worker 進程中可用。(這是必要的,因為函數僅作為參考進行醃製,而不是 bytecode。)

多進程資料載入中的隨機性

根據預設,每個 worker 的 PyTorch 種子都會設定為 base_seed + worker_id,其中 base_seed 是由主進程使用其 RNG 產生的一個長整數(因此,強制消耗 RNG 狀態)或指定的 generator。但是,在初始化 worker 時,其他函式庫的種子可能會重複,導致每個 worker 傳回相同的隨機數。(請參閱常見問題解答中的 此部分)。

worker_init_fn 中,您可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 存取為每個 worker 設定的 PyTorch 種子,並使用它在載入資料之前為其他函式庫設定種子。

記憶體固定

當主機到 GPU 的複製源自固定(頁面鎖定)記憶體時,速度會快得多。有關何時以及如何使用固定記憶體的詳細資訊,請參閱 使用固定記憶體緩衝區

對於資料載入,將 pin_memory=True 傳遞到 DataLoader 會自動將提取的資料張量放入固定記憶體中,從而實現更快地將資料傳輸到支援 CUDA 的 GPU。

預設的記憶體固定邏輯只會辨識張量以及包含張量的映射和可迭代物件。根據預設,如果固定邏輯看到一個自訂類型的批次(如果您有一個傳回自訂批次類型的 collate_fn,就會發生這種情況),或者如果批次的每個元素都是自訂類型,那麼固定邏輯就不會辨識它們,並且會在不固定記憶體的情況下傳回該批次(或這些元素)。若要為自訂批次或資料類型啟用記憶體固定,請在您的自訂類型上定義 pin_memory() 方法。

請參閱以下範例。

範例

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[source]

資料載入器結合了資料集和取樣器,並提供給定資料集上的迭代器。

DataLoader 支援單一或多程序載入的映射樣式和可迭代樣式資料集,可自訂載入順序,並提供可選的自動批次處理(排序)和記憶體固定。

如需更多詳細資訊,請參閱 torch.utils.data 文件頁面。

參數
  • dataset (Dataset) – 要從中載入資料的資料集。

  • batch_size (int, 可選) – 每個批次要載入多少個樣本(預設值:1)。

  • shuffle (bool, 可選) – 設定為 True 可在每個 epoch 重新洗牌資料(預設值:False)。

  • sampler (SamplerIterable, 可選) – 定義從資料集中提取樣本的策略。可以是任何已實作 __len__Iterable。如果指定,則不得指定 shuffle

  • batch_sampler (SamplerIterable, 可選) – 類似於 sampler,但一次返回一批索引。與 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, 可選) – 用於資料載入的子程序數。 0 表示資料將在主程序中載入。(預設值:0

  • collate_fn (Callable, 可選) – 合併樣本清單以形成 Tensor 的 minibatch。在使用映射樣式資料集的批次載入時使用。

  • pin_memory (bool, 可選) – 如果為 True,資料載入器會在返回 Tensor 之前將其複製到裝置/CUDA 固定記憶體中。如果您的資料元素是自訂型別,或者您的 collate_fn 返回的批次是自訂型別,請參閱下面的範例。

  • drop_last (bool, 可選) – 如果資料集大小無法被批次大小整除,則設定為 True 以捨棄最後一個不完整的批次。如果為 False 且資料集大小無法被批次大小整除,則最後一個批次會比較小。(預設值:False

  • timeout (數值, 可選) – 如果為正數,則表示從工作程序收集批次的逾時值。應始終為非負數。(預設值:0

  • worker_init_fn (Callable, 可選) – 如果不是 None,則在植入種子之後、資料載入之前,將使用工作程序 ID([0, num_workers - 1] 中的 int)作為輸入,在每個工作程序子程序上呼叫此函數。(預設值:None

  • multiprocessing_context (strmultiprocessing.context.BaseContext, 可選) – 如果為 None,則將使用您作業系統的預設 多程序環境。(預設值:None

  • generator (torch.Generator, 可選) – 如果不是 None,RandomSampler 將使用此 RNG 產生隨機索引,而多程序將使用此 RNG 為工作程序產生 base_seed。(預設值:None

  • prefetch_factor (int, 可選, 僅限關鍵字參數) – 每個工作程序預先載入的批次數。 2 表示所有工作程序總共會預先提取 2 * num_workers 個批次。(預設值取決於 num_workers 的設定值。如果 num_workers=0,則預設值為 None。否則,如果 num_workers > 0,則預設值為 2)。

  • persistent_workers (bool, 可選) – 如果為 True,則資料載入器在資料集被取用一次後不會關閉工作程序。這允許保持工作程序的 Dataset 實例處於活動狀態。(預設值:False

  • pin_memory_device (str, 可選) – 如果 pin_memoryTrue,則為要將 pin_memory 設定到的裝置。

警告

如果使用 spawn 啟動方法,則 worker_init_fn 不能是無法醃製的物件,例如 lambda 函數。如需與 PyTorch 中的多程序處理相關的更多詳細資訊,請參閱 多程序處理最佳實務

警告

len(dataloader) 啟發式方法是基於所用取樣器的長度。當 datasetIterableDataset 時,它會根據 len(dataset) / batch_size 返回一個估計值,並根據 drop_last 進行適當的捨入,而不管多程序載入配置為何。這表示 PyTorch 可以做出的最佳猜測,因為 PyTorch 信任使用者 dataset 程式碼能夠正確處理多程序載入以避免資料重複。

但是,如果分片導致多個工作程序具有不完整的最後一個批次,則此估計值可能仍然不準確,因為 (1) 否則完整的批次可能會被分解成多個批次,以及 (2) 當設定 drop_last 時,可能會捨棄超過一個批次的樣本。遺憾的是,PyTorch 通常無法偵測到此類情況。

如需這兩種資料集類型以及 IterableDataset 如何與 多程序資料載入 互動的更多詳細資訊,請參閱 資料集類型

警告

有關隨機種子的相關問題,請參閱重現性我的資料載入器工作程序傳回相同的隨機數多程序資料載入中的隨機性說明。

class torch.utils.data.Dataset(*args, **kwds)[來源]

表示Dataset的抽象類別。

所有表示從鍵到資料樣本的映射的資料集都應該繼承它。所有子類別都應該覆寫__getitem__(),支持獲取給定鍵的資料樣本。子類別也可以選擇覆寫__len__(),預期它會傳回許多Sampler實作和DataLoader的默認選項的資料集大小。子類別也可以選擇實作__getitems__(),以加速批次樣本載入。此方法接受批次樣本的索引清單,並傳回樣本清單。

備註

DataLoader預設會建構一個產生整數索引的索引取樣器。若要使其與具有非整數索引/鍵的映射樣式資料集一起使用,則必須提供自定義取樣器。

class torch.utils.data.IterableDataset(*args, **kwds)[來源]

可迭代的Dataset。

所有表示資料樣本可迭代物件的資料集都應該繼承它。當資料來自串流時,這種形式的資料集特別有用。

所有子類別都應該覆寫__iter__(),它會傳回此資料集中樣本的迭代器。

當子類別與DataLoader一起使用時,資料集中的每個項目都將從DataLoader迭代器中產生。當num_workers > 0時,每個工作程序都將擁有資料集物件的不同副本,因此通常希望獨立配置每個副本,以避免從工作程序傳回重複的資料。在工作程序中呼叫get_worker_info()時,它會傳回有關工作程序的資訊。它可以在資料集的__iter__()方法或DataLoaderworker_init_fn選項中使用,以修改每個副本的行為。

範例1:在__iter__()中跨所有工作程序分割工作負載

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

範例2:使用worker_init_fn跨所有工作程序分割工作負載

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[來源]

包裝張量的資料集。

每個樣本都將透過沿著第一個維度索引張量來檢索。

參數

*tensors (Tensor) – 第一個維度大小相同的張量。

class torch.utils.data.StackDataset(*args, **kwargs)[來源]

資料集作為多個資料集的堆疊。

這個類別對於組裝複雜輸入資料的不同部分很有用,這些部分以資料集的形式給出。

範例

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
參數
  • *args (Dataset) – 作為元組傳回的堆疊資料集。

  • **kwargs (Dataset) – 作為字典傳回的堆疊資料集。

class torch.utils.data.ConcatDataset(datasets)[來源]

資料集作為多個資料集的串聯。

這個類別對於組裝不同的現有資料集很有用。

參數

datasets (序列) – 要串聯的資料集清單

class torch.utils.data.ChainDataset(datasets)[來源]

用於鏈接多個IterableDataset的資料集。

這個類別對於組裝不同的現有資料集串流很有用。鏈接操作是動態完成的,因此使用此類別串聯大型資料集將會很有效率。

參數

datasets (IterableDataset可迭代物件) – 要鏈接在一起的資料集

class torch.utils.data.Subset(dataset, indices)[來源]

指定索引處的資料集子集。

參數
  • dataset (Dataset) – 整個資料集

  • indices (序列) – 為子集選擇的整個集合中的索引

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[來源]

處理每個批次中元素集合類型的通用整理函數。

該函數還打開了函數註冊表來處理特定的元素類型。default_collate_fn_map為張量、numpy陣列、數字和字串提供了默認的整理函數。

參數
  • batch – 要整理的單個批次

  • collate_fn_map (Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]]) – 可選的字典,用於從元素類型映射到相應的整理函數。如果元素類型不存在於此字典中,則此函數將按照插入順序遍歷字典的每個鍵,如果元素類型是鍵的子類別,則呼叫相應的整理函數。

範例

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

備註

每個整理函數都需要一個用於批次的positional argument,以及一個用於整理函數字典的keyword argument作為collate_fn_map

torch.utils.data.default_collate(batch)[來源]

接收一批資料,並將批次中的元素放入具有額外外部維度的張量中 - 批次大小。

確切的輸出類型可以是torch.Tensortorch.Tensor序列torch.Tensor集合,或者保持不變,具體取決於輸入類型。當在DataLoader中定義了batch_sizebatch_sampler時,這將用作整理的默認函數。

以下是根據批次中元素的類型,從輸入類型到輸出類型的一般映射。

  • torch.Tensor -> torch.Tensor(添加了一個外部維度批次大小)

  • NumPy 陣列 -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str(不變)

  • bytes -> bytes(不變)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

參數

batch – 要整理的單個批次

範例

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update(CustomType, collate_customtype_fn)
>>> default_collate(batch)  # Handle `CustomType` automatically
torch.utils.data.default_convert(data)[原始碼]

將每個 NumPy 陣列元素轉換為 torch.Tensor

如果輸入是 SequenceCollectionMapping,它會嘗試將內部的每個元素轉換為 torch.Tensor。如果輸入不是 NumPy 陣列,則保持不變。當 batch_samplerbatch_sizeDataLoader 中都未定義時,這將用作排序的預設函數。

從輸入類型到輸出類型的一般映射類似於 default_collate()。有關更多詳細信息,請參見那裡的描述。

參數

data – 要轉換的單個數據點

範例

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[原始碼]

返回有關當前 DataLoader 迭代器工作進程的信息。

在工作進程中調用時,這會返回一個保證具有以下屬性的對象

  • id:當前工作進程的 ID。

  • num_workers:工作進程的總數。

  • seed:為當前工作進程設置的隨機種子。此值由主進程 RNG 和工作進程 ID 決定。有關更多詳細信息,請參見 DataLoader 的文檔。

  • dataset:數據集對象在進程中的副本。請注意,這在不同進程中將是與主進程中不同的對象。

在主進程中調用時,這會返回 None

備註

當在傳遞給 DataLoaderworker_init_fn 中使用時,此方法可用於以不同方式設置每個工作進程,例如,使用 worker_id 配置 dataset 對象以僅讀取分片數據集的特定部分,或使用 seed 為數據集代碼中使用的其他庫設置種子。

返回類型

Optional[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[原始碼]

將數據集隨機拆分為給定長度的非重疊新數據集。

如果給定一個總和為 1 的分數列表,則將針對提供的每個分數自動計算長度為 floor(frac * len(dataset))。

計算長度後,如果有任何餘數,則將以循環方式將 1 個計數分配給長度,直到沒有餘數為止。

(可選)修復生成器以獲得可重複的結果,例如

範例

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
參數
  • dataset (Dataset) – 要拆分的數據集

  • lengths (序列) – 要生成的拆分的長度或分數

  • generator (Generator) – 用於隨機排列的生成器。

返回類型

List[Subset[T]]

類別 torch.utils.data.Sampler(data_source=None)[原始碼]

所有採樣器的基類。

每個 Sampler 子類都必須提供一個 __iter__() 方法,提供一種迭代數據集元素的索引或索引列表(批次)的方法,並且可以提供一個 __len__() 方法,該方法返回返回的迭代器的長度。

參數

data_source (Dataset) – 此參數未使用,將在 2.2.0 中刪除。您可能仍有利用它的自定義實現。

範例

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

備註

DataLoader 並未嚴格要求 __len__() 方法,但在涉及 DataLoader 長度的任何計算中都需要使用。

類別 torch.utils.data.SequentialSampler(data_source)[原始碼]

按順序對元素進行採樣,始終以相同的順序。

參數

data_source (Dataset) – 要从中採樣的數據集

類別 torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[原始碼]

隨機對元素進行採樣。如果不放回,則從打亂的數據集中採樣。

如果放回,則用戶可以指定 num_samples 來繪製。

參數
  • data_source (Dataset) – 要从中採樣的數據集

  • replacement (bool) – 如果為 True,則按需抽取樣本並放回,默認為 ``False``

  • num_samples (int) – 要繪製的樣本數量,默認值為 `len(dataset)`。

  • generator (Generator) – 採樣中使用的生成器。

類別 torch.utils.data.SubsetRandomSampler(indices, generator=None)[原始碼]

從給定的索引列表中隨機對元素進行採樣,不放回。

參數
  • indices (序列) – 索引序列

  • generator (Generator) – 採樣中使用的生成器。

類別 torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[原始碼]

以給定的概率(權重)從 [0,..,len(weights)-1] 中對元素進行採樣。

參數
  • weights (序列) – 一個權重的序列,不一定要加總為 1

  • num_samples (int) – 要抽取的樣本數量

  • replacement (bool) – 如果為 True,則會以取後放回的方式抽取樣本。如果不是,則會以不放回的方式抽取,這表示當某一列的樣本索引被抽取後,該列就不能再抽取到相同的索引。

  • generator (Generator) – 採樣中使用的生成器。

範例

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
類別 torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[原始碼]

包裝另一個採樣器以產生一個 mini-batch 的索引。

參數
  • sampler (Sampler可迭代物件) – 基礎採樣器。可以是任何可迭代物件

  • batch_size (int) – mini-batch 的大小。

  • drop_last (bool) – 如果為 True,則當最後一個 batch 的大小小於 batch_size 時,採樣器會捨棄該 batch

範例

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
類別 torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[原始碼]

將資料載入限制為資料集子集的採樣器。

它與 torch.nn.parallel.DistributedDataParallel 結合使用特別有用。在這種情況下,每個處理程序都可以將一個 DistributedSampler instance 作為 DataLoader 採樣器傳遞,並載入原始資料集中專屬於它的子集。

備註

假設資料集的大小是固定的,並且它的任何實例總是會以相同的順序返回相同的元素。

參數
  • dataset – 用於採樣的資料集。

  • num_replicas (int, 選用) – 參與分散式訓練的處理程序數量。預設情況下,會從當前的分散式群組中擷取 world_size

  • rank (int, 選用) – 當前處理程序在 num_replicas 中的排名。預設情況下,會從當前的分散式群組中擷取 rank

  • shuffle (bool, 選用) – 如果為 True(預設值),則採樣器會將索引打亂。

  • seed (int, 選用) – 如果 shuffle=True,則用於打亂採樣器的隨機種子。這個數字在分散式群組的所有處理程序中應該相同。預設值:0

  • drop_last (bool, 選用) – 如果為 True,則採樣器會捨棄資料的尾部,使其在副本數量之間平均分配。如果為 False,則採樣器會添加額外的索引,使資料在副本數量之間平均分配。預設值:False

警告

在分散式模式下,在每個 epoch 開始時建立 DataLoader 迭代器之前,必須呼叫 set_epoch() 方法,才能使打亂功能在多個 epoch 中正常運作。否則,將始終使用相同的排序。

範例

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

文件

取得 PyTorch 的完整開發者文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源