注意
點選此處下載完整的示例程式碼
CTC 強制對齊 API 教程¶
作者: Xiaohui Zhang, Moto Hira
強制對齊是將文字記錄與語音對齊的過程。本教程展示瞭如何使用 torchaudio.functional.forced_align() 將文字記錄與語音對齊,該函式是基於將語音技術擴充套件到 1,000 多種語言的工作開發的。
forced_align() 具有自定義的 CPU 和 CUDA 實現,這些實現比上面的純 Python 實現效能更高、更準確。它還可以透過特殊的 <star> token 處理缺失的文字記錄。
還有一個高階 API,torchaudio.pipelines.Wav2Vec2FABundle,它封裝了本教程中解釋的預處理/後處理,使得執行強制對齊變得容易。多語言資料的強制對齊 使用此 API 來演示如何對齊非英語文字記錄。
準備¶
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
2.7.0
2.7.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
import IPython
import matplotlib.pyplot as plt
import torchaudio.functional as F
首先,我們準備將要使用的語音資料和文字記錄。
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()
生成 emissions¶
forced_align() 接收 emission 和 token 序列,並輸出 token 的時間戳和它們的得分。
Emission 表示 token 的逐幀機率分佈,可以透過將波形輸入到聲學模型獲得。
Token 是文字記錄的數值表示。有許多方法可以對文字記錄進行 tokenize,但在這裡,我們只是將字母對映到整數,這是我們在訓練要使用的聲學模型時構建標籤的方式。
我們將使用一個預訓練的 Wav2Vec2 模型,torchaudio.pipelines.MMS_FA,來獲取 emission 並對文字記錄進行 tokenize。
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
emission, _ = model(waveform.to(device))
Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt
0%| | 0.00/1.18G [00:00<?, ?B/s]
3%|2 | 30.5M/1.18G [00:00<00:03, 319MB/s]
5%|5 | 61.0M/1.18G [00:00<00:03, 317MB/s]
8%|7 | 91.8M/1.18G [00:00<00:03, 319MB/s]
10%|# | 122M/1.18G [00:00<00:03, 308MB/s]
13%|#3 | 161M/1.18G [00:00<00:03, 342MB/s]
16%|#6 | 195M/1.18G [00:00<00:03, 347MB/s]
19%|#9 | 233M/1.18G [00:00<00:02, 365MB/s]
22%|##2 | 268M/1.18G [00:00<00:02, 333MB/s]
26%|##5 | 308M/1.18G [00:00<00:02, 357MB/s]
29%|##8 | 348M/1.18G [00:01<00:02, 375MB/s]
32%|###2 | 390M/1.18G [00:01<00:02, 392MB/s]
36%|###5 | 427M/1.18G [00:01<00:02, 376MB/s]
39%|###8 | 464M/1.18G [00:01<00:02, 364MB/s]
41%|####1 | 499M/1.18G [00:01<00:02, 360MB/s]
44%|####4 | 534M/1.18G [00:01<00:01, 364MB/s]
47%|####7 | 569M/1.18G [00:01<00:01, 360MB/s]
50%|##### | 607M/1.18G [00:01<00:01, 369MB/s]
53%|#####3 | 642M/1.18G [00:01<00:01, 368MB/s]
57%|#####6 | 680M/1.18G [00:01<00:01, 376MB/s]
59%|#####9 | 716M/1.18G [00:02<00:01, 375MB/s]
63%|######2 | 753M/1.18G [00:02<00:01, 377MB/s]
66%|######5 | 789M/1.18G [00:02<00:01, 374MB/s]
69%|######8 | 828M/1.18G [00:02<00:01, 383MB/s]
72%|#######1 | 864M/1.18G [00:02<00:00, 382MB/s]
75%|#######5 | 906M/1.18G [00:02<00:00, 397MB/s]
79%|#######8 | 947M/1.18G [00:02<00:00, 408MB/s]
82%|########1 | 986M/1.18G [00:02<00:00, 408MB/s]
85%|########5 | 1.00G/1.18G [00:02<00:00, 397MB/s]
89%|########8 | 1.04G/1.18G [00:03<00:00, 417MB/s]
92%|#########2| 1.08G/1.18G [00:03<00:00, 372MB/s]
95%|#########5| 1.12G/1.18G [00:03<00:00, 370MB/s]
98%|#########8| 1.15G/1.18G [00:03<00:00, 362MB/s]
100%|##########| 1.18G/1.18G [00:03<00:00, 369MB/s]

對文字記錄進行 tokenize¶
我們建立一個字典,將每個標籤對映到 token。
LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
print(f"{k}: {v}")
-: 0
a: 1
i: 2
e: 3
n: 4
o: 5
u: 6
t: 7
s: 8
r: 9
m: 10
k: 11
l: 12
d: 13
g: 14
h: 15
y: 16
b: 17
p: 18
w: 19
c: 20
v: 21
j: 22
z: 23
f: 24
': 25
q: 26
x: 27
將文字記錄轉換為 token 就像下面這樣簡單:
tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]
for t in tokenized_transcript:
print(t, end=" ")
print()
2 15 1 13 7 15 1 7 20 6 9 2 5 8 2 7 16 17 3 8 2 13 3 10 3 1 7 7 15 2 8 10 5 10 3 4 7
計算對齊¶
幀級對齊¶
現在我們呼叫 TorchAudio 的強制對齊 API 來計算幀級對齊。有關函式簽名的詳細資訊,請參閱 forced_align()。
def align(emission, tokens):
targets = torch.tensor([tokens], dtype=torch.int32, device=device)
alignments, scores = F.forced_align(emission, targets, blank=0)
alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
scores = scores.exp() # convert back to probability
return alignments, scores
aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
現在讓我們看看輸出。
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
0: 0 [-], 1.00
1: 0 [-], 1.00
2: 0 [-], 1.00
3: 0 [-], 1.00
4: 0 [-], 1.00
5: 0 [-], 1.00
6: 0 [-], 1.00
7: 0 [-], 1.00
8: 0 [-], 1.00
9: 0 [-], 1.00
10: 0 [-], 1.00
11: 0 [-], 1.00
12: 0 [-], 1.00
13: 0 [-], 1.00
14: 0 [-], 1.00
15: 0 [-], 1.00
16: 0 [-], 1.00
17: 0 [-], 1.00
18: 0 [-], 1.00
19: 0 [-], 1.00
20: 0 [-], 1.00
21: 0 [-], 1.00
22: 0 [-], 1.00
23: 0 [-], 1.00
24: 0 [-], 1.00
25: 0 [-], 1.00
26: 0 [-], 1.00
27: 0 [-], 1.00
28: 0 [-], 1.00
29: 0 [-], 1.00
30: 0 [-], 1.00
31: 0 [-], 1.00
32: 2 [i], 1.00
33: 0 [-], 1.00
34: 0 [-], 1.00
35: 15 [h], 1.00
36: 15 [h], 0.93
37: 1 [a], 1.00
38: 0 [-], 0.96
39: 0 [-], 1.00
40: 0 [-], 1.00
41: 13 [d], 1.00
42: 0 [-], 1.00
43: 0 [-], 0.97
44: 7 [t], 1.00
45: 15 [h], 1.00
46: 0 [-], 0.98
47: 1 [a], 1.00
48: 0 [-], 1.00
49: 0 [-], 1.00
50: 7 [t], 1.00
51: 0 [-], 1.00
52: 0 [-], 1.00
53: 0 [-], 1.00
54: 20 [c], 1.00
55: 0 [-], 1.00
56: 0 [-], 1.00
57: 0 [-], 1.00
58: 6 [u], 1.00
59: 6 [u], 0.96
60: 0 [-], 1.00
61: 0 [-], 1.00
62: 0 [-], 0.53
63: 9 [r], 1.00
64: 0 [-], 1.00
65: 2 [i], 1.00
66: 0 [-], 1.00
67: 0 [-], 1.00
68: 0 [-], 1.00
69: 0 [-], 1.00
70: 0 [-], 1.00
71: 0 [-], 0.96
72: 5 [o], 1.00
73: 0 [-], 1.00
74: 0 [-], 1.00
75: 0 [-], 1.00
76: 0 [-], 1.00
77: 0 [-], 1.00
78: 0 [-], 1.00
79: 8 [s], 1.00
80: 0 [-], 1.00
81: 0 [-], 1.00
82: 0 [-], 0.99
83: 2 [i], 1.00
84: 0 [-], 1.00
85: 7 [t], 1.00
86: 0 [-], 1.00
87: 0 [-], 1.00
88: 16 [y], 1.00
89: 0 [-], 1.00
90: 0 [-], 1.00
91: 0 [-], 1.00
92: 0 [-], 1.00
93: 17 [b], 1.00
94: 0 [-], 1.00
95: 3 [e], 1.00
96: 0 [-], 1.00
97: 0 [-], 1.00
98: 0 [-], 1.00
99: 0 [-], 1.00
100: 0 [-], 1.00
101: 8 [s], 1.00
102: 0 [-], 1.00
103: 0 [-], 1.00
104: 0 [-], 1.00
105: 0 [-], 1.00
106: 0 [-], 1.00
107: 0 [-], 1.00
108: 0 [-], 1.00
109: 0 [-], 0.64
110: 2 [i], 1.00
111: 0 [-], 1.00
112: 0 [-], 1.00
113: 13 [d], 1.00
114: 3 [e], 0.85
115: 0 [-], 1.00
116: 10 [m], 1.00
117: 0 [-], 1.00
118: 0 [-], 1.00
119: 3 [e], 1.00
120: 0 [-], 1.00
121: 0 [-], 1.00
122: 0 [-], 1.00
123: 0 [-], 1.00
124: 1 [a], 1.00
125: 0 [-], 1.00
126: 0 [-], 1.00
127: 7 [t], 1.00
128: 0 [-], 1.00
129: 7 [t], 1.00
130: 15 [h], 1.00
131: 0 [-], 0.79
132: 2 [i], 1.00
133: 0 [-], 1.00
134: 0 [-], 1.00
135: 0 [-], 1.00
136: 8 [s], 1.00
137: 0 [-], 1.00
138: 0 [-], 1.00
139: 0 [-], 1.00
140: 0 [-], 1.00
141: 10 [m], 1.00
142: 0 [-], 1.00
143: 0 [-], 1.00
144: 5 [o], 1.00
145: 0 [-], 1.00
146: 0 [-], 1.00
147: 0 [-], 1.00
148: 10 [m], 1.00
149: 0 [-], 1.00
150: 0 [-], 1.00
151: 3 [e], 1.00
152: 0 [-], 1.00
153: 4 [n], 1.00
154: 0 [-], 1.00
155: 7 [t], 1.00
156: 0 [-], 1.00
157: 0 [-], 1.00
158: 0 [-], 1.00
159: 0 [-], 1.00
160: 0 [-], 1.00
161: 0 [-], 1.00
162: 0 [-], 1.00
163: 0 [-], 1.00
164: 0 [-], 1.00
165: 0 [-], 1.00
166: 0 [-], 1.00
167: 0 [-], 1.00
168: 0 [-], 1.00
注意
對齊以 emission 的幀座標表示,這與原始波形不同。
它包含空白 token 和重複 token。以下是對非空白 token 的解釋。
31: 0 [-], 1.00
32: 2 [i], 1.00 "i" starts and ends
33: 0 [-], 1.00
34: 0 [-], 1.00
35: 15 [h], 1.00 "h" starts
36: 15 [h], 0.93 "h" ends
37: 1 [a], 1.00 "a" starts and ends
38: 0 [-], 0.96
39: 0 [-], 1.00
40: 0 [-], 1.00
41: 13 [d], 1.00 "d" starts and ends
42: 0 [-], 1.00
注意
當相同的 token 出現在空白 token 之後時,它不被視為重複,而是視為新的出現。
a a a b -> a b
a - - b -> a b
a a - b -> a b
a - a b -> a a b
^^^ ^^^
Token 級對齊¶
下一步是解決重複問題,以便每個對齊不依賴於先前的對齊。torchaudio.functional.merge_tokens() 計算 TokenSpan 物件,該物件表示文字記錄中的哪個 token 出現在什麼時間跨度。
token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
print("Token\tTime\tScore")
for s in token_spans:
print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
Token Time Score
i [ 32, 33) 1.00
h [ 35, 37) 0.96
a [ 37, 38) 1.00
d [ 41, 42) 1.00
t [ 44, 45) 1.00
h [ 45, 46) 1.00
a [ 47, 48) 1.00
t [ 50, 51) 1.00
c [ 54, 55) 1.00
u [ 58, 60) 0.98
r [ 63, 64) 1.00
i [ 65, 66) 1.00
o [ 72, 73) 1.00
s [ 79, 80) 1.00
i [ 83, 84) 1.00
t [ 85, 86) 1.00
y [ 88, 89) 1.00
b [ 93, 94) 1.00
e [ 95, 96) 1.00
s [101, 102) 1.00
i [110, 111) 1.00
d [113, 114) 1.00
e [114, 115) 0.85
m [116, 117) 1.00
e [119, 120) 1.00
a [124, 125) 1.00
t [127, 128) 1.00
t [129, 130) 1.00
h [130, 131) 1.00
i [132, 133) 1.00
s [136, 137) 1.00
m [141, 142) 1.00
o [144, 145) 1.00
m [148, 149) 1.00
e [151, 152) 1.00
n [153, 154) 1.00
t [155, 156) 1.00
詞級對齊¶
現在讓我們將 token 級對齊分組到詞級對齊。
def unflatten(list_, lengths):
assert len(list_) == sum(lengths)
i = 0
ret = []
for l in lengths:
ret.append(list_[i : i + l])
i += l
return ret
word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
音訊預覽¶
# Compute average score weighted by the span length
def _score(spans):
return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / num_frames
x0 = int(ratio * spans[0].start)
x1 = int(ratio * spans[-1].end)
print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
segment = waveform[:, x0:x1]
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
num_frames = emission.size(1)
# Generate the audio for each segment
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']
preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
i (1.00): 0.644 - 0.664 sec
preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
had (0.98): 0.704 - 0.845 sec
preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
that (1.00): 0.885 - 1.026 sec
preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
curiosity (1.00): 1.086 - 1.790 sec
preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
beside (0.97): 1.871 - 2.314 sec
preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
me (1.00): 2.334 - 2.414 sec
preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
at (1.00): 2.495 - 2.575 sec
preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
this (1.00): 2.595 - 2.756 sec
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
moment (1.00): 2.837 - 3.138 sec
視覺化¶
現在讓我們看看對齊結果,並將原始語音分割成單詞。
def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / emission.size(1) / sample_rate
fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
axes[0].set_title("Emission")
axes[0].set_xticks([])
axes[1].specgram(waveform[0], Fs=sample_rate)
for t_spans, chars in zip(token_spans, transcript):
t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
t0 = span.start * ratio
axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
axes[1].set_xlim([0, None])
fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

blank token 的不一致處理¶
將 token 級對齊拆分成單詞時,您會注意到一些空白 token 的處理方式不同,這使得結果的解釋有些模糊。
當我們繪製得分時,這很容易看出。下圖顯示了詞區域和非詞區域,以及非空白 token 的幀級得分。
def plot_scores(word_spans, scores):
fig, ax = plt.subplots()
span_xs, span_hs = [], []
ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
for t_span in word_spans:
for span in t_span:
for t in range(span.start, span.end):
span_xs.append(t + 0.5)
span_hs.append(scores[t].item())
ax.annotate(LABELS[span.token], (span.start, -0.07))
ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
ax.set_title("Frame-level scores and word segments")
ax.set_ylim(-0.1, None)
ax.grid(True, axis="y")
ax.axhline(0, color="black")
fig.tight_layout()
plot_scores(word_spans, alignment_scores)

在此圖中,空白 token 是那些沒有垂直條的高亮區域。您可以看到,有些空白 token 被解釋為單詞的一部分(紅色高亮),而其他空白 token 則不是(藍色高亮)。
造成這種情況的一個原因是模型在訓練時沒有為單詞邊界設定標籤。空白 token 不僅被視為重複,也被視為單詞之間的靜音。
但隨後出現了一個問題。緊隨單詞之後或接近單詞末尾的幀應該是靜音還是重複?
在上面的例子中,如果您回到之前的頻譜圖和詞區域圖,您會看到在“curiosity”中的“y”之後,多個頻率桶中仍然有一些活動。
如果將該幀包含在單詞中,會不會更準確?
不幸的是,CTC 沒有對此提供一個全面的解決方案。用 CTC 訓練的模型已知會表現出“尖峰”響應,也就是說,它們傾向於在標籤出現時達到峰值,但峰值不會持續標籤的整個時長。(注意:預訓練的 Wav2Vec2 模型傾向於在標籤出現開始時達到峰值,但這並非總是如此。)
[Zeyer 等人,2021] 對 CTC 的尖峰行為進行了深入分析。我們鼓勵那些有興趣進一步瞭解的人參考這篇論文。以下是論文中的一段引文,這正是我們在這裡面臨的問題。
尖峰行為在某些情況下可能存在問題, 例如當應用程式要求不使用空白標籤時, 例如為了獲得音素到文字記錄的具有意義的時間準確對齊。
高階用法:使用 <star> token 處理文字記錄¶
現在讓我們看看當文字記錄部分缺失時,如何使用能夠模擬任何 token 的 <star> token 來提高對齊質量。
這裡我們使用上面使用的同一個英語例子。但是我們從文字記錄中刪除了開頭的文字 “i had that curiosity beside me at”。使用這樣的文字記錄對齊音訊會導致現有單詞“this”的錯誤對齊。然而,這個問題可以透過使用 <star> token 來建模缺失的文字來緩解。
首先,我們擴充套件字典以包含 <star> token。
DICTIONARY["*"] = len(DICTIONARY)
接下來,我們擴充套件 emission tensor,增加與 <star> token 對應的額外維度。
star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
plot_emission(emission[0])

以下函式結合了所有過程,並一次性從 emission 計算詞段。
def compute_alignments(emission, transcript, dictionary):
tokens = [dictionary[char] for word in transcript for char in word]
alignment, scores = align(emission, tokens)
token_spans = F.merge_tokens(alignment, scores)
word_spans = unflatten(token_spans, [len(word) for word in transcript])
return word_spans
完整文字記錄¶
word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

帶 <star> token 的部分文字記錄¶
現在我們將文字記錄的第一部分替換為 <star> token。
transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)

preview_word(waveform, word_spans[0], num_frames, transcript[0])
* (1.00): 0.000 - 2.595 sec
preview_word(waveform, word_spans[1], num_frames, transcript[1])
this (1.00): 2.595 - 2.756 sec
preview_word(waveform, word_spans[2], num_frames, transcript[2])
moment (1.00): 2.837 - 3.138 sec
不帶 <star> token 的部分文字記錄¶
作為比較,以下示例對齊不使用 <star> token 的部分文字記錄。它演示了 <star> token 在處理刪除錯誤方面的效果。
transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)

結論¶
在本教程中,我們探討了如何使用 torchaudio 的強制對齊 API 來對齊和分割語音檔案,並展示了一個高階用法:在存在轉錄錯誤時,引入 <star> token 如何提高對齊精度。
致謝¶
感謝 Vineel Pratap 和 Zhaoheng Ni 開發並開源了強制對齊 API。
指令碼總執行時間: ( 0 分鐘 6.927 秒)