모바일 앱 환경에서는 latex 수식이 깨져 나타나므로 가급적 웹 환경에서 봐주시길 바랍니다.
다음과 같은 Finite sum 꼴의 optimization problem을 정의하자.
$\min f(x) := \sum_{i=1}^n f_i(x)$
여기서 $i$는 data index를 의미한다.
먼저 L-smooth의 descent lemma에서 시작하자.
함수 $f$가 L-smooth함은 다음을 의미한다.
$\lVert \nabla f(x) - \nabla f(y) \rVert \le L \lVert x - y \rVert$
모든 $x, y$에 대해서 $f$의 Gradient가 L-Lipshitz continuous할 때 우리는 함수 $f$가 L-smooth하다고 한다.
그리고 함수 $f$가 2번 미분 가능하다면 L-smooth는 다음과 필요충분 조건 관계이다.
$\nabla^2 f(x) \le L I$
이는 결국 $f$의 Hessian의 모든 singular value들이 L보다 작다는 의미이다.
이를 활용해 우리는 다음과 같은 Descent lemma를 유도할 수 있다.
$f(y) \le f(x) + \langle \nabla f(x), y - x \rangle + \frac{L}{2} \lVert y - x \rVert^2$
여기서 현재 우리는 SGD에 focus를 맞추고 있으므로 $x_{t+1} = x_t - \eta_t \nabla f_i(x_t)$를 집어넣자.
$f(x_{t+1}) \le f(x_t) - \eta_t \langle \nabla f(x_t), \nabla f_i(x_t) \rangle + \eta_t^2 \lVert \nabla f_i(x_t) \rVert^2$
위 식은 한 번의 SGD iteration을 통해서 우리가 함숫값을 감소시켜나가기 위해선 어떠한 조건들이 필요한가를 잘 보여준다.
step size $\eta_t$는 항상 0보다 큰 real number이므로 우선 $\langle \nabla f(x_t), \nabla f_i(x_t) \rangle$이 0보다 커야 한다. 다음은 $\lVert \nabla f_i(x_t) \rVert^2$가 너무 커서는 안 된다. 만약 이 값이 감소하는 값보다 더욱 커져버린다면 함숫값이 오히려 증가할 수도 있게 된다.
즉, 우리는 convergence를 위해서는 파라미터를 업데이트 하는 stochastic gradient와 True gradient가 align되어져야 하고 (inner product값이 0보다 커야 하고), stochastic gradient의 second moment는 매우 커서는 안 된다. (즉 upper bound되어져 있어야 한다.)
여기서 우리는 stochastic gradient의 second moment의 upper bound 가정이 stochastic optimization algorithm의 convergence rate에서 매우 중요함을 알 수 있다.
그리고 실제로 이 가정에 대한 연구는 지난 10년 동안 상당히 활발하게 이뤄져 왔다.
그리고 가장 최근에 나온 가정은 다음과 같다.
$\lVert \nabla f_i(x_t) \rVert^2 \le 2A (f(x_t) - f^{inf}) + B (\lVert \nabla f(x_t) \rVert) + C$
이 가정은 다음의 논문에서 소개되었다.
Better theory for SGD in the non-convex world. TMLR 2023. Khaned et al.
정리하자면 L-smooth한 $f$를 (convex or non-convex or strongly-convex) 어떤 stochastic gradient를 사용하는 알고리즘으로 최소화해나갈 때 convergence rate을 구하기 위해서는 다음이 필요하다.
$1$. \mathbb{E} \langle \nabla f(x_t), g_t \rangle \ge 0$
이는 우리가 파라미터를 업데이트할 때 사용되는 벡터 ($g_t$)가 True gradient $\nabla f(x_t)$와 잘 align된다는 의미이다.
즉 True graient (정확히는 Negative gradient)는 가장 가파른 하강 방향을 가리키고 이 방향과 너무 멀리 떨어지면 (즉, inner product가 0보다 작으면) 오히려 함숫값이 증가할 수도 있으므로 converge하지 않을 수도 있다.
더 나아가, 얼마만큼의 속도로 감소해나가는지 즉 convergence rate을 구하기 위해서는 이 값의 lower bound를 유도해야 한다.
$2$.\lVert g_t \rVert^2 \le$ something
이는 우리가 파라미터를 업데이트할 때 사용되는 벡터의 second moment가 너무 큰 것을 방지해준다.
즉 매 step마다 벡터의 방향이 너무 심하게 변동할 수도 있는 것을 방지해주고, 함숫값이 증가하는 방향으로의 벡터가 나올 가능성도 줄여준다.
위 모든 내용은 Descent lemma를 통해서 살펴보면 당연한 사실이다.
'Deep dive into Optimization' 카테고리의 다른 글
Optimization 심화 : Random Reshuffling (2) (1) | 2023.09.04 |
---|---|
Optimization 심화 : Random Reshuffling (1) (0) | 2023.08.31 |
Optimization 심화: SGD (2) (0) | 2023.08.20 |
Optimization 심화 1 : SGD (1) (1) | 2023.08.14 |
Deep dive into Optimization : Proximal gradient descent (2) (0) | 2023.06.14 |
댓글