
Recent Applications of Approximate Message Passing Algorithms for High-dimensional Statistical Estimation

Cynthia Rush, Columbia University

Joint work with

Ramji Venkataramanan (University of Cambridge)

February 12, 2018

1 / 31

High-dimensional Linear Regression

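[Figure: the linear model $y = A\beta_0 + w$, with $A$ of size $m \times N$]

Want to reconstruct $\beta_0$ from $y = A\beta_0 + w$

• $y$: measurement vector in $\mathbb{R}^m$

• $w$: measurement noise in $\mathbb{R}^m$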

• A: m × N design matrix

• Number of measurements m < N

• β0 has k < N non-zero elements, i.e. it is k-sparse

2 / 31

High-dimensional Linear Regression

Many Applications

• Channel Coding in Communications
  y = received sample, w = noise/interference, A = coding dictionary, β0 = message

• Imaging: Medical, Seismic, Compressive Sensing, ...
  y = measurements, w = sensor noise, A = basis representation, β0 = sparse image/signal

• Statistics/Machine Learning
  y = experimental outcome, w = model error, A = feature data, β0 = prediction coefficients

Problem sizes are large, so the computational complexity of the reconstruction algorithm is a concern.

3 / 31

High-dimensional Linear Regression


Goal: reconstruct k-sparse β0 from y = Aβ0 + w

Want to solve:

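$$\hat\beta = \arg\min_{\beta \in \mathbb{R}^N} \|y - A\beta\|^2 \quad \text{subject to} \quad \sum_{j=1}^{N} \mathbb{I}\{\beta_j \neq 0\} \le k,$$

i.e. subject to $\|\beta\|_0 \le k$.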

Unfortunately, a very hard combinatorial problem.

4 / 31

High-dimensional Linear Regression

Want to solve:

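$$\hat\beta = \arg\min_{\beta \in \mathbb{R}^N} \|y - A\beta\|^2 \quad \text{subject to} \quad \|\beta\|_0 \le k.$$

Unfortunately, a very hard combinatorial problem.

Instead, a convex relaxation:

$$\hat\beta = \arg\min_{\beta \in \mathbb{R}^N} \|y - A\beta\|^2 \quad \text{subject to} \quad \|\beta\|_1 \le \lambda.$$

If A satisfies certain conditions (e.g., RIP), then one can get a good estimate of the sparse β0 by solving a convex program (LASSO):

$$\hat\beta = \arg\min_{\beta \in \mathbb{R}^N} \|y - A\beta\|^2 + \lambda\|\beta\|_1$$

[Donoho ’06, Candes-Romberg-Tao ’06, . . . ]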

5 / 31

Approximate Message Passing (AMP)

AMP: low complexity, scalable algorithm studied to solve the high-dimensional linear regression task for compressed sensing.

Outline

1. AMP algorithm for the LASSO.
   • Derivation from message passing
   • Comparison to other LASSO solvers

2. General AMP algorithms.

3. State evolution and performance guarantees.

4. Generalizations and extensions of AMP.

6 / 31

Algorithmic Challenges

Want to solve:

$$\hat\beta = \arg\min_{\beta \in \mathbb{R}^N} \|y - A\beta\|^2 + \lambda\|\beta\|_1$$

Convex Optimization Tools for Solving the LASSO

1. Classic Interior Point Method:
   • Usually matrix-matrix multiplication or matrix decomposition
   • Appropriate for, say, N < 5000

2. Homotopy Methods (e.g. LARS)
   • Use the structure of the LASSO cost
   • Appropriate for, say, N < 50000

3. First-order Methods
   • Low computational complexity per iteration
   • Require many iterations

7 / 31

Solving the LASSO

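$$\hat\beta = \arg\min_{\beta} \|y - A\beta\|^2 + \lambda\|\beta\|_1$$

First-order methods: Iteratively generate estimates of β0

$$\beta^1, \beta^2, \ldots$$

1. Proximal Gradient (aka Iterative Soft-Thresholding, IST)

$$r^t = y - A\beta^t$$

$$\beta^{t+1} = \eta(\beta^t + sA^T r^t;\ s\lambda)$$

$$\eta(x; T) = \begin{cases} x - T, & x \ge T \\ 0, & -T < x < T \\ x + T, & x \le -T \end{cases}$$

[Figure: plot of the soft-threshold function η(x; T) against x, with −T and T marked on the axis]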
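The iteration above can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the speaker's code: the cost is taken to be scaled as ½‖y − Aβ‖² + λ‖β‖₁ (so that β + sAᵀr is exactly a gradient step), the step size is set to s = 1/‖A‖² (squared spectral norm), and the names `soft_threshold`, `lasso_cost` and `ista` are illustrative.

```python
import numpy as np

def soft_threshold(x, T):
    """eta(x; T): shrink every entry of x toward zero by T."""
    return np.sign(x) * np.maximum(np.abs(x) - T, 0.0)

def lasso_cost(A, y, beta, lam):
    """Assumed scaling of the LASSO cost: 0.5*||y - A beta||^2 + lam*||beta||_1."""
    return 0.5 * np.sum((y - A @ beta) ** 2) + lam * np.sum(np.abs(beta))

def ista(A, y, lam, n_iter=100):
    """Iterative soft-thresholding; returns the last estimate and the cost history."""
    s = 1.0 / np.linalg.norm(A, 2) ** 2      # assumed step size: 1 / (spectral norm)^2
    beta = np.zeros(A.shape[1])
    costs = [lasso_cost(A, y, beta, lam)]
    for _ in range(n_iter):
        r = y - A @ beta                                     # residual r^t
        beta = soft_threshold(beta + s * A.T @ r, s * lam)   # beta^{t+1}
        costs.append(lasso_cost(A, y, beta, lam))
    return beta, costs
```

With this step size the recorded cost should not increase from one iteration to the next, which is a convenient sanity check when experimenting with other values of s.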

8 / 31

Solving the LASSO

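2. Proximal Gradient + Momentum (FISTA/Nesterov)

$$\tilde\beta^t = \beta^t + \frac{t-1}{t+2}(\beta^t - \beta^{t-1}) \qquad \text{(momentum term)}$$

$$r^t = y - A\tilde\beta^t \qquad \text{(same as IST)}$$

$$\beta^{t+1} = \eta(\tilde\beta^t + sA^T r^t;\ s\lambda) \qquad \text{(same as IST)}$$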
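A sketch of the momentum variant, under the same assumptions as for plain soft-thresholding (cost scaled as ½‖y − Aβ‖² + λ‖β‖₁, step size s = 1/‖A‖²). The variable `beta_tilde` stands for the extrapolated point produced by the momentum term; the function names are illustrative.

```python
import numpy as np

def soft_threshold(x, T):
    """eta(x; T): shrink every entry of x toward zero by T."""
    return np.sign(x) * np.maximum(np.abs(x) - T, 0.0)

def lasso_cost(A, y, beta, lam):
    """Assumed scaling of the LASSO cost: 0.5*||y - A beta||^2 + lam*||beta||_1."""
    return 0.5 * np.sum((y - A @ beta) ** 2) + lam * np.sum(np.abs(beta))

def fista(A, y, lam, n_iter=200):
    """Proximal gradient with the (t - 1)/(t + 2) momentum weight."""
    s = 1.0 / np.linalg.norm(A, 2) ** 2      # assumed step size
    beta_prev = np.zeros(A.shape[1])
    beta = beta_prev.copy()
    for t in range(1, n_iter + 1):
        # Momentum term: extrapolate along the last displacement
        beta_tilde = beta + (t - 1) / (t + 2) * (beta - beta_prev)
        r = y - A @ beta_tilde               # residual at the extrapolated point
        beta_prev, beta = beta, soft_threshold(beta_tilde + s * A.T @ r, s * lam)
    return beta
```

The only difference from the IST sketch is that the residual and the thresholding step are evaluated at `beta_tilde` rather than at the current iterate.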

FISTA is good, but want faster convergence as N grows large

Can we use a message passing algorithm?

9 / 31

Assumptions

y = Aβ0 + w

• Let us first assume that the entries of A are iid $N(0, 1/m)$

• Dimensions of A: m, N large, $m/N \to \delta$ (δ is Θ(1))

AMP is derived as an approximation of loopy belief propagation for dense graphs

[Donoho-Maleki-Montanari ’09], [Rangan ’11], [Krzakala et al ’12], [Schniter ’11], . . .

10 / 31

Min-Sum Message Passing

Want to compute $\hat\beta = \arg\min_{\beta} \sum_{a=1}^{m} (y_a - [A\beta]_a)^2 + \lambda \sum_{i=1}^{N} |\beta_i|$

[Figure: factor graph representation of the LASSO cost, with variable nodes β1, β2, ..., βN]

• m factor nodes corresponding to $(y_1 - [A\beta]_1)^2, (y_2 - [A\beta]_2)^2, \ldots, (y_m - [A\beta]_m)^2$

• N variable nodes corresponding to β1, β2, . . . , βN

• Edge between factor node a and variable node i if $A_{a,i} \neq 0$.

11 / 31

• Min-sum: popular optimization algorithm for graph-structured cost

• AMP derived from min-sum on above graph

11 / 31

Min-Sum Message Passing

For i = 1, . . . , N and a = 1, 2, . . . , m:

$$M^t_{i \to a}(\beta_i) = \lambda|\beta_i| + \sum_{b \in [m] \setminus a} M^{t-1}_{b \to i}(\beta_i)$$

$$M^t_{a \to i}(\beta_i) = \min_{\beta \setminus \beta_i} \Big[ (y_a - [A\beta]_a)^2 + \sum_{j \in [N] \setminus i} M^t_{j \to a}(\beta_j) \Big]$$

11 / 31

Why Min-sum?

• Vast literature justifying and studying use of min-sum. For example [Murphy, Weiss, Jordan ’99].

• Computes exact minimum when graph is a tree (no cycles)

• Generally not guaranteed to converge on ‘loopy’ graphs

• Nonetheless works well in some ‘loopy’ applications (coding, machine vision, compressed sensing, ...)

12 / 31

Further Limitations

But computing these messages is infeasible:

• Each message needs to be computed for all $\beta_i \in \mathbb{R}$

• There are mN such messages

12 / 31

Quadratic Approximation of Messages

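[Figure: factor graph with variable nodes β1, β2, ..., βN]

Messages approximated by two numbers (via quadratic approximation):

$$r^t_{a \to i} = y_a - \sum_{j \in [N] \setminus i} A_{aj}\,\beta^t_{j \to a}, \qquad \beta^{t+1}_{i \to a} = \eta\Big( \sum_{b \in [m] \setminus a} A_{bi}\, r^t_{b \to i};\ \theta_t \Big)$$

We still have mN messages in each step . . .

13 / 31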

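$$r^t_{a \to i} = y_a - \sum_{j \in [N]} A_{aj}\,\beta^t_{j \to a} + A_{ai}\,\beta^t_{i \to a}$$

$$\beta^{t+1}_{i \to a} = \eta\Big( \sum_{b \in [m]} A_{bi}\, r^t_{b \to i} - A_{ai}\, r^t_{a \to i};\ \theta_t \Big)$$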

• Weak dependence between messages and target indices

• Neglecting dependence altogether gives IST

• A more careful analysis, using Taylor approximations . . .

14 / 31

The AMP algorithm

AMP iteratively produces estimates $\beta^0 = 0, \beta^1, \ldots, \beta^t, \ldots$

$$r^t = y - A\beta^t + \frac{r^{t-1}}{m}\|\beta^t\|_0$$

$$\beta^{t+1} = \eta(\beta^t + A^T r^t;\ \theta_t)$$

• $r^t$ is the ‘modified residual’ after step t

• η denoises the effective observation to produce $\beta^{t+1}$

Compare to Iterative Soft-Thresholding

$$r^t = y - A\beta^t$$

$$\beta^{t+1} = \eta(\beta^t + sA^T r^t;\ s\lambda)$$
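The AMP recursion can be sketched in a few lines of NumPy. This is a hedged illustration, not the speaker's implementation: the threshold is assumed to be θ_t = α‖r^t‖/√m, combining the threshold form ατ_t and the estimate τ_t² ≈ ‖r^t‖²/m that both appear later in the talk, and α is a tuning parameter whose value the slides do not specify.

```python
import numpy as np

def soft_threshold(x, T):
    """eta(x; T): shrink every entry of x toward zero by T."""
    return np.sign(x) * np.maximum(np.abs(x) - T, 0.0)

def amp_lasso(A, y, alpha, n_iter=30):
    """AMP with a soft-threshold denoiser and the modified residual."""
    m, N = A.shape
    beta = np.zeros(N)        # beta^0 = 0
    r = y.copy()              # r^0 = y, because beta^0 = 0 has no non-zero entries
    for _ in range(n_iter):
        theta = alpha * np.linalg.norm(r) / np.sqrt(m)     # assumed threshold rule
        beta = soft_threshold(beta + A.T @ r, theta)       # beta^{t+1}
        # Modified residual: plain residual plus (r^t / m) * ||beta^{t+1}||_0
        r = y - A @ beta + (r / m) * np.count_nonzero(beta)
    return beta
```

Compared with the IST sketch, the only structural change is the extra term added to the residual; there is no step size because the columns of A are assumed to have roughly unit norm (entries iid N(0, 1/m)).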

15 / 31

The AMP algorithm

AMP iteratively produces estimates $\beta^0 = 0, \beta^1, \ldots, \beta^t, \ldots$

$$r^t = y - A\beta^t + \frac{r^{t-1}}{m}\|\beta^t\|_0$$

$$\beta^{t+1} = \eta(A^T r^t + \beta^t;\ \theta_t)$$

With the assumptions:

• Entries of A are iid $N(0, 1/m)$

• Dimensions of A: m, N large, $m/N \to \delta$ (δ is constant)

The momentum term in $r^t$ ensures that asymptotically

$$A^T r^t + \beta^t \approx \beta_0 + \tau_t Z, \quad \text{where } Z \text{ is } N(0, 1)$$

⇒ The effective observation $A^T r^t + \beta^t$ is the true signal observed in independent Gaussian noise.

16 / 31

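Example: y = Aβ0

A : m × N = 2000 × 4000; β0 has 500 non-zeros ∼ iid unif ±1

[Figure: two histograms of $A^T r^t + \beta^t$ at t = 10, one computed with the modified residual $r^t = y - A\beta^t + \frac{r^{t-1}}{m}\|\beta^t\|_0$ and one with the plain residual $r^t = y - A\beta^t$]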

17 / 31

• Here: empirical observation at a single t for specific m, N

• Later: rigorous proof that the statistical properties are exact in the limit of large m, N, for all t
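The experiment on this slide can be imitated with a short script. It is only a sketch under assumptions: the soft-threshold denoiser with threshold α‖r^t‖/√m and the value of α are choices made here, not settings given in the slides, and `effective_noise` is a hypothetical helper name. It returns the effective noise $A^T r^t + \beta^t - \beta_0$ after a given number of iterations, together with ‖r^t‖²/m, with or without the extra term in the residual.

```python
import numpy as np

def soft_threshold(x, T):
    """eta(x; T): shrink every entry of x toward zero by T."""
    return np.sign(x) * np.maximum(np.abs(x) - T, 0.0)

def effective_noise(A, y, beta0, alpha, n_iter, modified=True):
    """Run n_iter iterations; return (A^T r^t + beta^t - beta0, ||r^t||^2 / m)."""
    m, N = A.shape
    beta = np.zeros(N)
    r = y.copy()
    for _ in range(n_iter):
        theta = alpha * np.linalg.norm(r) / np.sqrt(m)   # assumed threshold rule
        beta = soft_threshold(beta + A.T @ r, theta)
        # With modified=False the extra term is omitted (plain residual)
        extra = (r / m) * np.count_nonzero(beta) if modified else 0.0
        r = y - A @ beta + extra
    return A.T @ r + beta - beta0, np.linalg.norm(r) ** 2 / m

def make_problem(m, N, k, seed=0):
    """Noiseless instance: A iid N(0, 1/m), beta0 with k entries drawn uniformly from {-1, +1}."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((m, N)) / np.sqrt(m)
    beta0 = np.zeros(N)
    beta0[rng.choice(N, size=k, replace=False)] = rng.choice([-1.0, 1.0], size=k)
    return A, A @ beta0, beta0
```

Calling `make_problem(2000, 4000, 500)` reproduces the slide's dimensions; a histogram of the returned noise vector for `modified=True` and `modified=False` gives the two panels, and comparing `np.var(noise)` with the returned ‖r^t‖²/m is one way to check the Gaussian-noise picture numerically.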

17 / 31

AMP vs the Rest

$y = A\beta_0 + w$, $w$ iid ∼ $N(0, \sigma^2)$, $\text{MSE} = \frac{1}{N}\|\beta^t - \beta_0\|^2$

18 / 31

General AMP Framework

In the talk so far:

• Sparse signal (unknown signal prior distribution)

• Goal to minimize LASSO cost

Generalization:

• Known signal prior distribution (sparsity-inducing or not)

• Goal to minimize mean squared error (MSE)

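Let $y = A\beta_0 + w$, with $\beta_0$ iid ∼ $p_\beta$ and $w$ iid ∼ $N(0, \sigma^2)$

$$r^t = y - A\beta^t + \frac{r^{t-1}}{m} \sum_{i=1}^{N} \eta'_{t-1}\big([A^T r^{t-1} + \beta^{t-1}]_i\big)$$

$$\beta^{t+1} = \eta_t(A^T r^t + \beta^t)$$

The function $\eta_t$ is chosen to denoise the effective observation, producing $\beta^{t+1}$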

19 / 31

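Choosing ηt(·)

Let $y = A\beta_0 + w$, with $\beta_0$ iid ∼ $p_\beta$ and $w$ iid ∼ $N(0, \sigma^2)$

$$r^t = y - A\beta^t + \frac{r^{t-1}}{m} \sum_{i=1}^{N} \eta'_{t-1}\big([A^T r^{t-1} + \beta^{t-1}]_i\big)$$

$$\beta^{t+1} = \eta_t(A^T r^t + \beta^t)$$

KEY: For large m, N, at each time step t

$$A^T r^t + \beta^t \approx \beta_0 + \tau_t Z, \quad \text{where } Z \text{ is } N(0, 1)$$

• $p_\beta$ known: the Bayes-optimal choice of $\eta_t$ minimizes $E\|\beta_0 - \beta^{t+1}\|^2$. It equals

$$\eta_t(s) = E[\beta_0 \mid \beta_0 + \tau_t Z = s]$$

• $p_\beta$ unknown: partial knowledge about β0 can guide the choice of $\eta_t$.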
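As a worked instance of the conditional-expectation denoiser, the sketch below assumes a three-point prior: β0 = +1 or −1 with probability ε/2 each and β0 = 0 with probability 1 − ε. This prior is an illustrative assumption (it matches the ±1 signal of the earlier example, but the talk does not fix a prior here), and `eta_bayes` is a hypothetical name. The posterior weights are proportional to the prior mass times the Gaussian likelihood of s, and the common factor exp(−s²/(2τ²)) cancels.

```python
import numpy as np

def eta_bayes(s, tau, eps):
    """E[beta0 | beta0 + tau*Z = s] under the assumed prior on {-1, 0, +1}."""
    s = np.asarray(s, dtype=float)
    # Log-likelihood ratios of beta0 = +1 and beta0 = -1 relative to beta0 = 0
    a = (2.0 * s - 1.0) / (2.0 * tau ** 2)
    b = (-2.0 * s - 1.0) / (2.0 * tau ** 2)
    c = np.maximum(np.maximum(a, b), 0.0)    # largest exponent, subtracted for stability
    w_plus = 0.5 * eps * np.exp(a - c)       # unnormalized posterior weight of +1
    w_minus = 0.5 * eps * np.exp(b - c)      # unnormalized posterior weight of -1
    w_zero = (1.0 - eps) * np.exp(-c)        # unnormalized posterior weight of 0
    # Posterior mean: (+1)*P(+1 | s) + (-1)*P(-1 | s)
    return (w_plus - w_minus) / (w_plus + w_minus + w_zero)
```

The result is an odd function of s that stays in [−1, 1]; for small τ it approaches a hard decision between −1, 0 and +1, while for large τ it shrinks observations toward zero.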

20 / 31

General AMP Framework

To summarize:

LASSO:

• Sparse signal (unknown signal prior distribution)

• Goal to minimize LASSO cost

• Use denoiser η(·) as soft-threshold

Generalization:

• Known signal prior distribution (sparsity-inducing or not)

• Goal to minimize mean squared error (MSE)

• Use denoiser $\eta_t(s) = E[\beta_0 \mid \beta_0 + N(0, \tau_t^2) = s]$.

In both cases, $A^T r^t + \beta^t \approx \beta_0 + N(0, \tau_t^2)$.

Choice of denoiser determines the type of problem AMP solves.

21 / 31

The Modified Residual [Donoho-Maleki-Montanari ’09]

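Assume $A_{ij} \sim N(0, 1/m)$ and $w_i \sim N(0, \sigma^2)$.

Suppose instead that the residual is computed without the correction term $\frac{r^{t-1}}{m} \sum_{i=1}^{N} \eta'_{t-1}\big([A^T r^{t-1} + \beta^{t-1}]_i\big)$:

$$r^t = y - A\beta^t$$

Then the effective observation is:

$$\beta^t + A^T r^t = \beta_0 + \underbrace{A^T w}_{\approx N(0, \sigma^2)} + \underbrace{(I - A^T A)}_{\approx N(0, 1/m)} (\beta^t - \beta_0)$$

$$\approx \beta_0 + \sqrt{\sigma^2 + \frac{E\|\beta_0 - \beta^t\|^2}{m}}\ Z$$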

22 / 31

State Evolution

Define $\tau_t^2$ as the noise variance in the effective observation after step t.

$$\beta^t + A^T r^t \approx \beta_0 + \tau_t Z, \quad Z \sim N(0, I).$$

If τ1, τ2, . . . is decreasing, we get a more ‘pure’ view of β0 as the algorithm iterates.

SE Equations

Set $\tau_0^2 = \sigma^2 + \frac{E\|\beta\|^2}{m}$,

$$\tau_t^2 = \sigma^2 + \frac{E\|\beta - \beta^t\|^2}{m} = \sigma^2 + \frac{E\|\beta - \eta_{t-1}(\beta + \tau_{t-1} Z)\|^2}{m}$$

$Z \sim N(0, 1)$ independent of $\beta \sim p_\beta$.

State evolution is a scalar recursion that allows us to predict the performance of AMP at every iteration!
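The recursion can be evaluated numerically. The sketch below makes several assumptions that are not fixed by the slide: the prior is the three-point distribution on {−1, 0, +1} with P(β ≠ 0) = ε, the denoiser is the soft threshold with threshold ατ, and the expectation is estimated by Monte Carlo over one fixed set of samples. For iid entries, E‖β‖²/m equals E[β²]/δ with δ = m/N, which is the per-entry form used in the code.

```python
import numpy as np

def soft_threshold(x, T):
    """eta(x; T): shrink every entry of x toward zero by T."""
    return np.sign(x) * np.maximum(np.abs(x) - T, 0.0)

def state_evolution(sigma2, delta, eps, alpha, n_iter, n_samples=200_000, seed=0):
    """Return [tau_0^2, tau_1^2, ..., tau_{n_iter}^2] for the assumed prior and denoiser."""
    rng = np.random.default_rng(seed)
    # Samples from the assumed prior: -1 and +1 w.p. eps/2 each, 0 otherwise
    beta = rng.choice([-1.0, 0.0, 1.0], size=n_samples, p=[eps / 2, 1 - eps, eps / 2])
    Z = rng.standard_normal(n_samples)
    tau2 = sigma2 + np.mean(beta ** 2) / delta          # tau_0^2
    history = [tau2]
    for _ in range(n_iter):
        tau = np.sqrt(tau2)
        estimate = soft_threshold(beta + tau * Z, alpha * tau)   # denoised scalar observation
        tau2 = sigma2 + np.mean((beta - estimate) ** 2) / delta  # next tau^2
        history.append(tau2)
    return history
```

Plotting the returned sequence against the empirical value ‖r^t‖²/m from an actual AMP run is a simple way to see how closely the scalar recursion tracks the algorithm.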

23 / 31

Assumptions for Performance Guarantees

We make the following assumptions:

• Measurement matrix: entries i.i.d. ∼ $N(0, 1/m)$.

• Signal: i.i.d. ∼ $p_\beta$, sub-Gaussian.

• Measurement noise: i.i.d. ∼ $p_W$, sub-Gaussian, $E[w_i^2] = \sigma^2$.

• De-noising functions $\eta_t$: Lipschitz continuous with weak derivative $\eta'_t$ which is differentiable except possibly at a finite number of points, with bounded derivative everywhere it exists.

24 / 31

Performance Guarantees

Theorem (Rush, Venkataramanan ’16)

Under the assumptions of the previous slide, with constants $K_t, \kappa_t$, for $\Delta \in (0, 1)$ and $t \ge 0$,

$$P\left( \left| \frac{1}{m}\|\beta^{t+1} - \beta_0\|^2 - (\tau_{t+1}^2 - \sigma^2) \right| \ge \Delta \right) \le K_t\, e^{-\kappa_t N \Delta^2}.$$

Constants in the Bound:

• Constants $K_t = K_1 (K_2)^t (t!)^{10}$ and $\kappa_t = \kappa_1 \kappa_2^{-t} (t!)^{-18}$, where $K_1, K_2, \kappa_1, \kappa_2 > 0$ are universal constants.

• Indicates how large t can get for the deviation probability → 0: $t = o\left( \frac{\log N}{\log\log N} \right)$

• Result holds for more general class of loss functions (beyond MSE).

• Refines an asymptotic result proved by Bayati, Montanari [Trans. IT ’11]

• The finite-sample result above implies the asymptotic result (via Borel-Cantelli), i.e. with δ = m/N

$$\lim_{N \to \infty} \frac{1}{N}\|\beta^{t+1} - \beta_0\|^2 \overset{a.s.}{=} \delta(\tau_{t+1}^2 - \sigma^2).$$

25 / 31

Back to LASSO

It can be shown [Bayati, Montanari ’12] that

$$\lim_{t \to \infty} \lim_{N \to \infty} \frac{1}{N}\|\beta^t - \hat\beta\|^2 \overset{a.s.}{=} 0,$$

for $\hat\beta$, the LASSO minimizer, and $\beta^t$, the AMP estimate at time t.

(AMP threshold has one-to-one map with LASSO parameter λ. Assumes i.i.d. Gaussian A.)

Moreover, AMP performance guarantees with the above imply an asymptotic result for the LASSO minimizer:

$$\lim_{N \to \infty} \frac{1}{N}\|\beta_0 - \hat\beta\|^2 \overset{a.s.}{=} \delta(\tau_*^2 - \sigma^2),$$

where $\tau_*^2 = \lim_{t \to \infty} \tau_t^2$ with $\tau_t$ given by state evolution.

26 / 31

Proof Idea of Performance Guarantees

Show $\beta^t + A^T r^t \sim \beta_0 + \tau_t Z$, with $\tau_t$ given by state evolution.

Steps:

1. Characterize the conditional distribution of the effective observation and residual as a sum of i.i.d. Gaussians plus a deviation term.

Show:

$$(\beta^t + A^T r^t - \beta_0) \mid \{\text{past}, \beta_0, w\} \overset{d}{=} \tau_t Z_t + \Delta_t,$$

$$(r^t - w) \mid \{\text{past}, \beta_0, w\} \overset{d}{=} \sqrt{\tau_t^2 - \sigma^2}\ Z_t + \Delta_t,$$

2. Concentration results show the deviation terms are small with high probability.

27 / 31

AMP Extensions/Generalizations

• Non-Gaussian noise distributions (GAMP) [Rangan ’11]

• Different measurement matrices:

  • Sub-Gaussian [Bayati, Lelarge, Montanari ’15]

  • Right orthogonally-invariant (VAMP) [Schniter, Rangan, Fletcher ’16, ’17]

  • Spatially-coupled (for improved MSE performance) [Donoho, Javanmard, Montanari ’13]

• Signals with dependent entries and non-separable denoisers [Ma, Rush, Baron ’17], [Berthier, Montanari, Nguyen ’17]

28 / 31

AMP Extensions/Generalizations

Different measurement models:

• Bilinear Models [Parker, Schniter, Cevher ’14]

• Multiple Measurement Vectors [Ziniel, Schniter ’13]

• Matrix Factorization [Kabashima, Krzakala, Mezard, Sakata, Zdeborova ’16]

• Blind Deconvolution

• Low-rank Matrix Estimation [Rangan, Fletcher ’12], [Lesieur, Krzakala, Zdeborova ’15]

• Principal Component Analysis [Deshpande, Montanari ’14], [Montanari, Richard ’16]

• Stochastic Block Model [Deshpande, Abbe, Montanari ’16]

• Replica Method [Barbier, Dia, Macris, Krzakala, Lesieur, Zdeborova ’15]

29 / 31

AMP Summary

$y = A\beta_0 + w$

$$r^t = y - A\beta^t + \frac{r^{t-1}}{m} \sum_{i=1}^{N} \eta'_{t-1}\big([A^T r^{t-1} + \beta^{t-1}]_i\big)$$

$$\beta^{t+1} = \eta_t(A^T r^t + \beta^t)$$

AMP: First-order iterative algorithm

• Theory assumes Gaussian A and iid/exchangeable $p_\beta$

• Sharp theoretical guarantees determined by simple scalar iteration. E.g., $\frac{1}{N}\|\beta_0 - \beta^{t+1}\|^2 \approx \delta(\tau_{t+1}^2 - \sigma^2)$

• AMP can be run even without knowing $p_\beta$ (our result shows that $\tau_t^2$ concentrates on $\|r^t\|^2/m$)

• Knowing $p_\beta$ can help choose a good denoiser $\eta_t$

30 / 31

Open Questions

• Theoretical results for general A matrices (iid uniform Bernoulli, partial DFT, . . . )

• Connections between AMP and classical optimization techniques

AMP

$$r^t = y - A\beta^t + \frac{r^{t-1}}{m}\|\beta^t\|_0$$

$$\beta^{t+1} = \eta(A^T r^t + \beta^t;\ \alpha\tau_t)$$

Nesterov/FISTA

$$\tilde\beta^t = \beta^t + \frac{t-1}{t+2}(\beta^t - \beta^{t-1})$$

$$r^t = y - A\tilde\beta^t$$

$$\beta^{t+1} = \eta(\tilde\beta^t + sA^T r^t;\ s\lambda)$$

31 / 31
