自定义Fbank训练

2021-05-19  本文已影响0人  静一下1

def feature_wav(wav_file, pre_emphasis=0.97, n_filter=40, frame_len_s=0.032, frame_shift_s=0.01):

    import numpy as np

    from scipy.io import wavfile

    from scipy.fftpack import dct

    import matplotlib.pyplot as plt

    #读取语音数据

    fs, sig = wavfile.read(wav_file)

    #对语音能量归一化

    Ena = sum(sig ** 2)

    sig = sig / np.sqrt(Ena)

    #预加重

    sig = np.append(sig[0], sig[1:] - pre_emphasis * sig[:-1])

    #分帧加窗

    def framing(frame_len_s, frame_shift_s, fs, sig):

        sig_n = len(sig)

        frame_len_n, frame_shift_n = int(

            round(fs * frame_len_s)), int(round(fs * frame_shift_s))

        num_frame = int(

            np.ceil(float(sig_n - frame_len_n) / frame_shift_n) + 1)

        pad_num = frame_shift_n * (num_frame - 1) + \

            frame_len_n - sig_n 

        pad_zero = np.zeros(int(pad_num))    # ????0

        pad_sig = np.append(sig, pad_zero)

        frame_inner_index = np.arange(0, frame_len_n)

        frame_index = np.arange(0, num_frame) * frame_shift_n

        frame_inner_index_extend = np.tile(frame_inner_index, (num_frame, 1))

        frame_index_extend = np.expand_dims(frame_index, 1)

        each_frame_index = frame_inner_index_extend + frame_index_extend

        each_frame_index = each_frame_index.astype(np.int, copy=False)

        frame_sig = pad_sig[each_frame_index]

        return frame_sig

#    frame_len_s = 0.025

#    frame_shift_s = 0.01

    frame_sig = framing(frame_len_s, frame_shift_s, fs, sig)

    window = np.hamming(int(round(frame_len_s * fs)))

    frame_sig *= window

    #快速傅里叶变换

    def stft(frame_sig, nfft=512):

        frame_spec = np.fft.rfft(frame_sig, nfft)

        frame_mag = np.abs(frame_spec)

        frame_pow = (frame_mag ** 2) * 1.0 / nfft

        return frame_pow

    nfft = 512

    frame_pow = stft(frame_sig, nfft)

    #梅尔滤波 0全部替换成无穷小 结果以10为底取对数,并乘20

    def mel_filter(frame_pow, fs, n_filter, nfft):

        mel_min = 0   

        mel_max = 2595 * np.log10(1 + fs / 2.0 / 700) 

        mel_points = np.linspace(mel_min, mel_max, n_filter + 2)

        hz_points = 700 * (10 ** (mel_points / 2595.0) - 1)    

        filter_edge = np.floor(hz_points * (nfft + 1) / fs)   

        fbank = np.zeros((n_filter, int(nfft / 2 + 1)))

        for m in range(1, 1 + n_filter):

            f_left = int(filter_edge[m - 1])    

            f_center = int(filter_edge[m])      

            f_right = int(filter_edge[m + 1])   

            for k in range(f_left, f_center):

                fbank[m - 1, k] = (k - f_left) / (f_center - f_left)

            for k in range(f_center, f_right):

                fbank[m - 1, k] = (f_right - k) / (f_right - f_center)

        filter_banks = np.dot(frame_pow, fbank.T)

        filter_banks = np.where(

            filter_banks == 0, np.finfo(float).eps, filter_banks)

        filter_banks = 20 * np.log10(filter_banks)  # dB

        return filter_banks

    filter_banks = mel_filter(frame_pow, fs, n_filter, nfft)

    return filter_banks

在特征提取部分将ta.compliance.kaldi.fbank改为自定义的特征提取函数

由于输出结果是numpy结构的 需要进行转换变为FloatTensor形式

设置了一个输出标志 ,如果这个函数被调用了则输出:DIYFbank

同样在标准化这里也设置个标志:norm

spec-augment:频率掩蔽+时间掩蔽,忽略时间扭曲

设置标签:mask

训练过程:

三部分都会调用

测试阶段:

只会用到自定义的Fbank和normalization两个部分

词错率为8.401%,与8.276%相比变化增加了0.125%,用自定义Fbank是可行的!

上一篇下一篇

猜你喜欢

热点阅读