AI人工智能大数据学习+数据库知识

研究线性模型训练中损失变化的规律和最优学习率的影响

2024-11-01  本文已影响0人  久别重逢已经那边v发

探究一维线性模型训练中,测试损失随训练步数变化的缩放定律及其最优学习率影响,并研究多维线性模型训练的缩放定律,确定参数以符合特定损失衰减模式。

研究大模型的缩放定律对减少其训练开销至关重要,即最终的测试损失如何随着训练步数和模型大小的变化而变化?本题中,我们研究了训练线性模型时的缩放定律。

  1. 在本小问中,考虑使用梯度下降学习一个一维线性模型的情况。

设学习率\eta\in(0,\frac{1}{3}],那么T≥0步迭代之后的测试损失的期望

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{w_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{w_T}(x)-y)^2]

是多少?

  1. 现在我们在第一小问的设定下,考虑学习率\eta被调到最优的情况,求函数g(T),使得当T\rightarrow+\infty时,以下条件成立:

\left|\underset{η\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O(\frac{(\log T)^2}{T^2})

  1. 一个常常被观测到的实验现象是大语言模型的预训练过程大致遵循Chinchilla缩放定律:

\overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C

其中\overline{\mathcal{L}}_{N,T}是在经过T步训练后具有N个参数的模型的测试损失的期望,ABaβC是常数。现在我们举一个训练多维线性模型的例子,使其也遵循类似的缩放定律。

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{\mathbf{w}_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{\mathbf{w}_T}(x)-y)^2]为以学习率\eta\in(0,\frac{1}{3}]对其有N个参数的线性模型进行T≥0步训练后的测试损失的期望。

请求出αβC,使得\forall\gamma>0,\forall c>0,当T=N^{c+o(1)}N足够大时,以下条件成立:

\epsilon(N,T):=\frac{\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}-C}{\frac{A}{N^\alpha}+\frac{B}{T^\beta}}

(\log N+\log T)^{-γ}\leq \epsilon(N,T)\leq(\log N+\log T)^γ。即\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}=\tilde{\Theta}(N^{-\alpha}+T^{-\beta})+C,其中\tilde{\Theta}表示忽略任何关于\log N\log T的多项式。

解:

  1. 首先,我们来计算测试损失的期望\overline{\mathcal{L}}_{\eta,T}

由于xy是独立的随机变量,且y的条件分布是N(3x, 1),我们可以写出测试损失的期望为:

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(w_T x - y)^2]

由于y=3x+\epsilon,其中\epsilon\sim N(0, 1)且独立于x,我们可以将y替换为3x+\epsilon

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{x,\epsilon}[\frac{1}{2}(w_T x - (3x+\epsilon))^2]

展开并利用\mathbb{E}[\epsilon^2]=1\mathbb{E}[x^2]=1(因为x\sim N(0, 1)):

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_x[\frac{1}{2}(w_T^2 x^2 - 6w_T x^2 + 9x^2 + \epsilon^2 - 6w_T x \epsilon + 3w_T^2 x^2)]

由于\epsilonx是独立的,我们可以分别计算期望:

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9)\mathbb{E}[x^2] + \frac{1}{2}\mathbb{E}[\epsilon^2]

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9) + \frac{1}{2}

现在我们需要计算w_T的期望值。由于w_t的更新规则是w_{t+1}=w_t-\eta\nabla l_t(w_t),我们有:

\nabla l_t(w_t) = w_t x_t - y_t = w_t x_t - (3x_t + \epsilon)

因此,更新规则变为:

w_{t+1} = w_t - \eta(w_t x_t - 3x_t - \epsilon)

取期望并利用\mathbb{E}[x_t]=0\mathbb{E}[\epsilon]=0

\mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - \eta(3\mathbb{E}[x_t^2])

由于x_t^2的期望是1,我们有:

\mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - 3\eta

由于w_0=0,我们可以递归地计算w_T

\mathbb{E}[w_T] = -3\eta T

\mathbb{E}[w_T]代入测试损失的期望中:

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}((-3\eta T)^2 - 6(-3\eta T) + 9) + \frac{1}{2}

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(9\eta^2 T^2 + 18\eta T + 9) + \frac{1}{2}

\overline{\mathcal{L}}_{\eta,T}=\frac{9\eta^2 T^2 + 18\eta T + 10}{2}

  1. 接下来,我们需要找到g(T)

首先,我们需要最小化\overline{\mathcal{L}}_{\eta,T}关于\eta。我们可以通过设置\frac{d\overline{\mathcal{L}}_{\eta,T}}{d\eta}=0来找到最优的学习率\eta^*

\frac{d}{d\eta}(\frac{9\eta^2 T^2 + 18\eta T + 10}{2})=9\eta T^2 + 18T=0

解得:

\eta^* = \frac{2}{3T}

\eta^*代入\overline{\mathcal{L}}_{\eta,T}中,我们得到最小化测试损失的表达式:

\overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{2}{3T})^2 T^2 + 18(\frac{2}{3T}) T + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{4}{9T^2}) T^2 + 18(\frac{2}{3T}) T + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{4 + 12 + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{26}{2}

\overline{\mathcal{L}}_{\eta^*,T}=13

现在,我们需要找到g(T),使得当T\rightarrow+\infty时,以下条件成立:

\left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right)

由于我们已经找到了最优的学习率\eta^*,我们可以将\overline{\mathcal{L}}_{\eta^*,T}视为\mathcal{I}_{n,T}的下界。因此,我们需要找到一个函数g(T),使得当T趋向于无穷大时,\overline{\mathcal{L}}_{\eta^*,T}g(T)之间的差异满足上述条件。

考虑到\overline{\mathcal{L}}_{\eta^*,T}是一个常数13,我们可以推断g(T)应该也是一个常数,因为测试损失的期望在最优学习率下不随T变化。因此,我们可以选择g(T)=13

现在,我们需要验证这个选择是否满足条件:

\left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right)

由于\mathcal{I}_{n,T}的最小值是13,我们有:

\left|13-13\right|=0

显然,0=O\left(\frac{(\log T)^2}{T^2}\right),因为当T趋向于无穷大时,\frac{(\log T)^2}{T^2}趋向于0。因此,我们的选择g(T)=13是正确的。

综上所述,g(T)=13满足题目中的条件。

3.为了解决这个问题,我们需要推导出多维线性模型在给定数据分布下的缩放定律。根据题目描述,我们有一个线性模型,其参数遵循特定的缩放定律。我们将通过以下步骤来解决这个问题:

步骤 1: 理解数据分布

数据分布 \mathcal{D} 是通过 Zipf 分布来选择输入向量的非零坐标,然后根据该坐标的值来生成输出 y。这意味着大部分的数据集中在较少的非零坐标上。

步骤 2: 定义损失函数

损失函数 \overline{\mathcal{L}}_{\eta,T} 是在给定学习率 \eta 和训练步数 T 后,模型参数 \mathbf{w} 的测试损失的期望。

步骤 3: 推导缩放定律

我们需要找到 \alpha\beta,和 C 使得损失函数符合 \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C 的形式。

对于 \alpha 的推导:

对于 \beta 的推导:

对于 C 的推导:

步骤 4: 确定 \alpha\beta,和 C

为了确定 \alpha\beta,和 C,我们需要进行以下分析:

步骤 5: 验证条件

我们需要验证 \epsilon(N,T) 的条件是否成立。这通常涉及到对 \overline{\mathcal{L}}_{N,T} 进行详细的分析,并证明它符合给定的缩放形式。这通常需要数学上的证明和/或实验验证。

综上所述,我们可以假设 \alpha = \frac{1}{b}\beta = \frac{1}{2}C 是一个正数。然而,为了得到精确的值,我们需要更深入的分析和实验数据。在实际应用中,这些参数通常是通过实验来确定的。

上一篇 下一篇

猜你喜欢

热点阅读