sgd+α: 確率的勾配降下法の現在と未来

Post on 08-Sep-2014

10.091 Views

Category:

Technology

7 Downloads

Preview:

Click to see full reader

DESCRIPTION

SGDの最新の拡張手法を紹介: Importance-aware UpdateやNormalized Online Learningなど.SGD+αはここまで出来る!

TRANSCRIPT

東京大学 情報理工学系研究科大岩 秀和 / @kisa12012

SGD+α確率的勾配降下法の現在と未来

2013/10/17 PFIセミナー

自己紹介

•大岩 秀和 (a.k.a. @kisa12012)•所属: 東大 数理情報 D2 (中川研)•研究: 機械学習・言語処理•オンライン学習/確率的最適化/スパース正則化 etc...

•前回のセミナー: 能動学習入門•PFI: インターン(10) -> アルバイト(-12)

2

今日の話 1/2

•みんな大好き(?)確率的勾配降下法•Stochastic Gradient Descent (SGD)•オンライン学習の文脈では,Online Gradient Decent (OGD)と呼ばれる

•SGDは便利だけど,使いにくい所も•ステップ幅の設定方法とか

3

今日の話 2/2

•SGDに+αで出来る拡張の話をします•最近提案されたトッピング(研究)を紹介•ステップ幅設定/自動正規化など

4

Plain SGD Topping

1. Plain SGD

5

基本問題設定

6

•(損失)最小化問題•値が最小となる を求めたい• は凸関数•関数がN個の関数に分解可能•必須ではないですが,今回はこの条件で進めます

f(·)

minw

f(w)

w⇤

f(w)

w⇤

f(w) =NX

n=1

ft(w)

Plain SGD

7

x1 x2 ................................ xN

wt

Plain SGD

8

x1 x2 ................................ xN

データを一つランダムにピックアップ

wt

Plain SGD

9

x1 x2 ................................ xN

選んだデータに対応する勾配でパラメータ更新

wt+1 = wt � ⌘t@f2(wt)

Plain SGD

10

x1 x2 ................................ xN

用いる損失関数は様々

wt+1 = wt � ⌘t@f2(wt)二乗損失 (回帰)

(wTx� y)2 ヒンジ損失 (分類)

max(0, 1� ywTx)

Plain SGD

11

•関数を一つだけサンプルして,勾配を計算

•関数 の値が一番小さくなる 方向へパラメータを更新

• でステップの幅を調整•微分不可能な場合も劣勾配で

fnt(·)

⌘t

wt+1 = wt � ⌘trfnt(wt)

Pros and Cons of Plain SGD

•長所•大規模データに有効 (Bottou+ 11)•”そこそこの”解がすぐに欲しい時•実装・デバッグ・実験サイクルを回すのが楽•ノウハウ集 (Bottou 12)•最適解への収束証明あり

12

wt+1 = wt � ⌘trfnt(wt)

Pros and Cons of Plain SGD

•短所•ステップ幅で収束性が大きく変化•Overshoot, Undershoot•前処理しないと性能が劇的に悪化•正規化, TF-IDF•厳密な最適解が欲しい場合は遅い

13

損失(対数)

時間

SGD

GD

SGD+α•時代はビッグデータ•複雑な最適化よりシンプルで軽いSGD•しかし,SGDも不便な部分が多い•SGD+α•+αで,より効果的なアルゴリズムへ•+αで,欠点の少ないアルゴリズムへ•「それ,実はSGD+αで出来るよ?」

14

今日紹介する+α• Importance-aware Update (Karampatziakis+ 11)•ステップ幅の問題を緩和•Normalized Online Learning (Stéphane+ 13)•前処理なし,オンラインで特徴量の正規化•Linear Convergence SGD (Le Roux+ 12)•バッチデータに対して,線形収束するSGD•他にもAdaGrad/省メモリ化等を紹介したかったですが,略

15

2. Importance-aware Update

16

Overshoot / Undershoot

17

ステップ幅が大きすぎる 小さすぎる

SGDはステップ幅設定に失敗すると,劇的に悪化

ステップ幅設定は大変

•Overshootで生じるnan/infの嵐•Cross-Validationで最適ステップ幅探しの旅•つらい•ステップ幅選択に悩みたくない•Importance-aware Update•キーワード: Invariance, Safety

18

w = (inf, inf, . . . )

Invariance•ステップ幅設定をh倍 -> データ1個分の更新h回へ再設定

19

Importance-aware Update(Karampatziakis+ 11)

•Invarianceを満たすステップ幅の再設定法•線形予測器では変化するのはステップ幅のみ•主な損失関数のステップ幅は,閉じた式で計算可能

•L2正則化等が入っても大丈夫•Regret Boundの証明あり

20

Importance-aware step width

21

Table 1: Importance Weight Aware Updates for Various Loss FunctionsLoss `(p, y) Update s(h)

Squared (y � p)2 p�y

x

>x

⇣1� e�h⌘x

>x

Logistic log(1 + e�yp) W (e

h⌘x

>x+yp+e

yp

)�h⌘x

>x�e

yp

yx

>x

for y 2 {�1, 1}Exponential e�yp

py�log(h⌘x

>x+e

py

)

x

>xy

for y 2 {�1, 1}

Logarithmic y log y

p

+ (1� y) log 1�y

1�p

if y = 0p�1+

p(p�1)

2+2h⌘x

>x

x

>x

if y = 1p�

pp

2+2h⌘x

>x

x

>x

Hellinger (pp�p

y)2 � (p1� p�

p1� y)2

if y = 0p�1+

14 (12h⌘x

>x+8(1�p)

3/2)

2/3

x

>x

if y = 1p� 1

4 (12h⌘x>x+8p

3/2)

2/3

x

>x

Hinge max(0, 1� yp) �ymin�h⌘, 1�yp

x

>x

�for y 2 {�1, 1}

⌧ -Quantileif y > p ⌧(y � p)if y p (1� ⌧)(p� y)

if y > p �⌧ min(h⌘, y�p

⌧x

>x

)if y p (1� ⌧)min(h⌘, p�y

(1�⌧)x

>x

)

(6) gives a di↵erential equation whose solution is theresult of a continuous gradient descent process.

As a sanity check we rederive (5) using (6). Forsquared loss @`

@p

= p� y and we get a linear ODE:

s0(h) = ⌘((wt

� s(h)x)>x� y), s(0) = 0

whose solution is exactly (5).

3.1 Other Loss Functions

Using (6) as our framework, we can derive step sizesfor many popular loss function as summarized in ta-ble 1. And even when it is not possible to have a closedform solution, one could approximate s(h) o✏ine usingnumerical integration.

For the logistic loss, the solution involves the LambertW function: W (z)eW (z) = z, and the solution can be

verified using W 0(z) = W (z)

z(1+W (z))

. The exponentialloss also fits nicely into our framework.

For the logarithmic loss the di↵erential equation has noexplicit form for all y 2 [0, 1]. The table presents thecommon case y 2 {0, 1}. In this case each value of ygives rise to a di↵erential equation whose solution hasan explicit form. Note that here the solutions to thedi↵erential equation satisfy a second degree equationand hence each branch has two solutions. We haveselected the one that pushes the prediction towardssmaller losses. Also, when optimizing log loss witha linear model the predictions should be clipped intoan interval contained in (0, 1) to avoid attaining aninfinite loss. The expression in table 1 can be used tofind the smallest importance weight h0 that would hitthe clipping point. Then the update should use theminimum of h and h0.

A similar situation arises for the Hellinger loss. The

solution to (6) has no simple form for all y 2 [0, 1] butfor y 2 {0, 1} we get the expressions in table 1.

3.1.1 Hinge Loss and Quantile Loss

Two other commonly used loss function are the hingeloss and the ⌧ -quantile loss where ⌧ 2 [0, 1] is a pa-rameter function. These are di↵erentiable everywhereexcept at one point and at that point the subdi↵eren-tial contains zero.

Hence, for the hinge loss, a valid expression for (6) is

s0(h) =

(�⌘y y(w � s(h)x)>x < 1

0 y(w � s(h)x)>x � 1

The first branch (together with s(0) = 0) gives s(h) =�yh⌘ for y(w + yh⌘x)>x < 1. Otherwise, i.e. when

h � hhinge

= 1�yw

>x

⌘x

>x

, s(h) is a constant. Here hhinge

isthe importance weight that would make the updatedprediction lie at the hinge. To maintain continuity athhinge

we set s(h) = �yhhinge

⌘. In conclusion

s(h) = �ymin(h, hhinge

)⌘

This matches the intuition when one thinks aboutthe limit of infinitely many infinitely small updates:If the importance weight is large enough the processwill bring the prediction up to y and make no furtherprogress. And for a small enough importance weight,the hingle loss looks like a linear function.

The quantile loss is similar and the update rule firstcomputes the importance weight h0 that would takethe updated prediction at the point of nondi↵erentia-bility and then multiplies the gradient by min(h, h0).

3.2 Variable Learning Rate

To handle a decaying learning rate ⌘t

, we just need toslightly modify (6). Let ⌘

t

(u) be the value of the learn-ing rate u timesteps after time t. Then our di↵erential

(Karampatziakis+ 11) より

ステップ幅の再設定式

Safety

•Importance-aware Updateとなった二乗損失やヒンジ損失は,Safetyの性質を持つ

22

領域を超えない

w

Tt+1x� y

w

Tt x� y

� 0

が必ず満たされる

Safety

No more step width war!

•SafetyによりOvershootの危険性が減る•初期ステップ幅を大きめにとれる•ステップ幅の精密化により,精度も改善•賢いステップ幅選択方法は他にも提案•(Duchi+ 10), (Schaul+ 13)...

23

3. Normalized Online Learning

24

特徴量の正規化

25

•各特徴量のスケールに強い影響を受ける•スケールの上限/下限の差が大きいほど,理論的にも実証的にも性能悪化

•バッチ学習の場合は前処理で正規化する場合がほとんど•オンライン学習では,前処理が不可能な場合がある•全部のデータを前もって用意出来ない etc.

x = (1.0, 5.2, . . . )

x = (1000.0, 5.2, . . . ) x = (0.001, 5.2, . . . )

Normalized Online Learning(Stéphane+ 13)

26

................................

各特徴量に,最大値保存用のボックスを設置wt = (1.0, 2.0, . . . , 5.0)

s1 s2 sD

Normalized Online Learning

27

データを一つランダムにピックアップwt = (1.0, 2.0, . . . , 5.0)

x2 = (2.0, 1.0, . . . , 5.0)

................................s1 s2 sD

Normalized Online Learning

28

wt = (1.0, 2.0, . . . , 5.0)

x2 = (2.0, 1.0, . . . , 5.0)

選択したデータの各特徴量の値が最大値を超えていないかチェック

................................s1 s2 sD

Normalized Online Learning

29

x2 = (2.0, 1.0, . . . , 5.0)

もし超えていたら,正規化せずに過去データを処理してしまった分,重みを補正

2.0 ................................s2 sD

If 2.0 > s1

wt = (1.0⇥ s212.02

, 2.0, . . . , 5.0)

Normalized Online Learning

30

x2 = (2.0, 1.0, . . . , 5.0)

あとは,サンプルしてきたデータを使って,正規化しながら確率的勾配法でアップデート

2.0 ................................s2 sD

wt+1 = wt � ⌘tg (@f2(wt), s1:D)

Normalized Online Learning

•オンライン処理しながら自動で正規化•スケールを(あまり)気にせず,SGDを回せるように!•スケールも敵対的に設定されるRegret Boundの証明付き

31

Algorithm 1 NG(learning rate ⌘t

)

1. Initially wi

= 0, si

= 0, N = 0

2. For each timestep t observe example (x, y)

(a) For each i, if |xi

| > si

i. wi

wis2i

|xi|2

ii. si

|xi

|(b) y =

Pi

wi

xi

(c) N N +

Pi

x

2i

s

2i

(d) For each i,i. w

i

wi

� ⌘t

t

N

1

s

2i

@L(y,y)

@wi

tic gradient descent.

The vector element si

stores the magnitude of feature i ac-cording to s

ti

= max

t

02{1...t} |xt

0i

|. These are updated andmaintained online in steps 2.(a).ii, and used to rescale theupdate on a per-feature basis in step 2.(d).i.

Using N makes the learning rate (rather than feature scale)control the average change in prediction from an update.Here N/t is the average change in the prediction excluding⌘, so multiplying by 1/(N/t) = t/N causes the averagechange in the prediction to be entirely controlled by ⌘.

Step 2.(a).i squashes a weight i when a new scale is en-countered. Neglecting the impact of N , the new value isprecisely equal to what the weight’s value would have beenif all previous updates used the new scale.

Many other online learning algorithms can bemade scale invariant using variants of this ap-proach. One attractive choice is adaptive gradientdescent [McMahan and Streeter, 2010, Duchi et al., 2011]since this also has per-feature learning rates. The nor-malized version of adaptive gradient descent is given inalgorithm 2.

In order to use this, the algorithm must maintain the sum of

gradients squared Gi

=

P(x,y) observed

⇣@L(y,y)

@wi

⌘2

forfeature i in step 2.d.i. The interaction between N and G issomewhat tricky, because a large average update (i.e. mostfeatures have a magnitude near their scale) increases thevalue of G

i

as well as N implying the power on N must bedecreased to compensate. Similarly, we reduce the poweron s

i

and |xi

| to 1 throughout. The more complex updaterule is scale invariant and the dependence on N introducesan automatic global rescaling of the update rule.

In the next sections we analyze and justify this algorithm.We demonstrate that NAG competes well against a set ofpredictors w with predictions (w>x) bounded by some con-stant over all the inputs x

t

seen during training. In practice,

Algorithm 2 NAG(learning rate ⌘)

1. Initially wi

= 0, si

= 0, Gi

= 0, N = 0

2. For each timestep t observe example (x, y)

(a) For each i, if |xi

| > si

i. wi

wisi|xi|

ii. si

|xi

|(b) y =

Pi

wi

xi

(c) N N +

Pi

x

2i

s

2i

(d) For each i,

i. Gi

Gi

+

⇣@L(y,y)

@wi

⌘2

ii. wi

wi

� ⌘q

t

N

1

sipGi

@L(y,y)

@wi

as this is potentially sensitive to outliers, we also consider asquared norm version of NAG, which we refer to as sNAGthat is a straightforward modification—we simply keep theaccumulator s

i

=

Px2

i

and usep

si

/t in the update rule.That is, normalization is carried using the standard devia-tion (more precisely, the square root of the second moment)of each feature, rather than the max norm. With respect toour analysis below, this simple modification can be inter-preted as changing slightly the set of predictors we com-pete against, i.e. predictors with predictions bounded by aconstant only over the inputs within 1 standard deviation.Intuitively, this is more robust and appropriate in the pres-ence of outliers. While our analysis focuses on NAG, inpractice, sNAG sometimes yield improved performance.

4 The Scaling Adversary Setting

In common machine learning practice, the choice of unitsfor any particular feature is arbitrary. For example, whenestimating the value of a house, the land associated with ahouse may be encoded either in acres or square feet. Tomodel this effect, we propose a scaling adversary, which ismore powerful than the standard adversary in adversarialonline learning settings.

The setting for analysis is similar to adversarial online lin-ear learning, with the primary difference in the goal. Thesetting proceeds in the following round-by-round fashionwhere

1. Prior to all rounds, the adversary commits to a fixedpositive-definite matrix S. This is not revealed to thelearner.

2. On each round t,

(a) The adversary chooses a vector xt

such that||S1/2x

t

||1 1, where S1/2 is the principal

(Stéphane+ 13)より

4. Linear Convergence SGD

32

線形収束するSGD

33

•Plain SGDの収束速度•一般的な条件の下で凸関数 •滑らかで強凸•使用データが予め固定されている場合•SGD+αで線形収束が可能に•厳密な最適解を得たい場合もSGD+α

O(1/pT )

O(1/T )

f(w)� f(w⇤)

O(cT )

Stochastic Average Gradient(Le Roux+ 12)

34

x1 x2 ................................ xN

wt

Stochastic Average Gradient

35

x1 x2 ................................ xN

................................

各データに,勾配保存用のボックスを一つ用意

@fN (·)@f1(·) @f2(·)

wt

Stochastic Average Gradient

36

x1 x2 ................................ xN

................................

データを一つランダムにピックアップ

@fN (·)@f1(·) @f2(·)

wt

Stochastic Average Gradient

37

x1 x2 ................................ xN

................................

選んだデータに対応する勾配情報を更新

@fN (·)@f1(·)

wt

@f2(wold

) 昔の勾配はステル

@f2(wt)

Stochastic Average Gradient

38

x1 x2 ................................ xN

................................

全勾配情報を使って,重みベクトルを更新

@fN (·)@f1(·)

wt+1 = wt �⌘tN

NX

n=1

@fn(·)

@f2(wt)

新しい勾配もあれば 古い勾配もある

線形収束するSGD

• が強凸かつ各 が滑らかな時,線形収束•線形予測器ならば,一データにつきスカラー(float/double)を一つ持てば良い

•正則化項を加えたい場合•SAGでは,L1を使ったスパース化の収束性は未証明 (近接勾配法)

•SDCA [Shalev+ 13], MISO[Mairal 13]39

fn(·)f

まとめ•SGD+α•ステップ幅設定/自動正規化/線形収束化•その他,特徴適応型のステップ幅調整/省メモリ化等,SGD拡張はまだまだ終わらない

•フルスタックなSGDピザが出来る..?•近いうちに,ソルバーの裏側でよしなに動いてくれる..はず?

•そんなソルバーを募集中40

• L. Bottou, O.Bousquet, “The Tradeoffs of Large-Scale Learning”, Optimization for Machine Learning, 2011.

• L. Bottou, “Stochastic Gradient Descent Tricks”, Neural Networks, 2012.

• Nikos Karampatziakis, John Langford, "Online Importance Weight Aware Updates", UAI, 2011.

• John C. Duchi, Elad Hazan, Yoram Singer, "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization", JMLR, 2011.

• Tom Schaul, Sixin Zhang and Yann LeCun., "No more Pesky Learning Rates", ICML, 2013.

• Stéphane Ross, Paul Mineiro, John Langford, "Normalized Online Learning", UAI, 2013.

• Nicolas Le Roux, Mark Schmidt, Francis Bach, “Stochastic Gradient Method with an Exponential Convergence Rate for Finite Training Sets”, NIPS, 2012.

• Shai Shalev-Shwartz, Tong Zhang, “Stochastic Dual Coordinate Ascent Methods for Regularized Loss Minimization”, JMLR, 2013.

• Julien Mairal, “Optimization with First-Order Surrogate Functions”, ICML, 2013.

参考文献

41

top related