chainer の trainer 解説と nsteplstm について

50
Chainer の Trainer 解説と NStepLSTM について 株式会社レトリバ © 2017 Retrieva, Inc.

Upload: retrieva-inc

Post on 21-Apr-2017

746 views

Category:

Engineering


6 download

TRANSCRIPT

Page 1: Chainer の Trainer 解説と NStepLSTM について

Chainer の Trainer 解説とNStepLSTM について株式会社レトリバ

© 2017 Retrieva, Inc.

Page 2: Chainer の Trainer 解説と NStepLSTM について

⾃⼰紹介• ⽩⼟慧(シラツチ ケイ)• 株式会社レトリバ• 2016年4⽉⼊社

• Ruby on Rails / JavaScript• フロントエンド側の⼈間

• ⼤学時代は複雑ネットワーク科学の研究• Chainer ⼊⾨中

© 2017 Retrieva, Inc. 2

Page 3: Chainer の Trainer 解説と NStepLSTM について

アジェンダ• 第1部 Chainer における Trainer の解説• 第2部 NStepLSTM との格闘

© 2017 Retrieva, Inc. 3

Page 4: Chainer の Trainer 解説と NStepLSTM について

アンケート• Chainer を使っている⽅• Chainer の Trainer を使っている⽅• LSTM を使っている⽅• NStepLSTM を使っている⽅• NStepLSTM と Trainer を使っている⽅

© 2017 Retrieva, Inc. 4

Page 5: Chainer の Trainer 解説と NStepLSTM について

第1部 Chainer における Trainer

© 2017 Retrieva, Inc. 5

Page 6: Chainer の Trainer 解説と NStepLSTM について

Chainer における Trainer• Chainer 1.11.0 から導⼊された学習フレームワーク• batchの取り出し、forward/backward が抽象化されている• 進捗表⽰、モデルのスナップショットなど

• Trainer 後から⼊⾨した⼈(私も)は、MNIST のサンプルがTrainerで抽象化されていて、何が起きているのかわからない• 以前から Chainer を使っている⼈は、Trainer なしで動かして

いることが多い

© 2017 Retrieva, Inc. 6

Page 7: Chainer の Trainer 解説と NStepLSTM について

Trainer 全体図 (train_mnist.py)

© 2017 Retrieva, Inc. 7

Trainer

Updater (StandardUpdater)

Iterator

Optimizer

Classifier

model (MLP)

train dataset

EvaluatorIterator

test dataset

Extensions

• dump_graph, snapshot• LogReport, PrintReport• ProgressBar

• converter• loss_func• device

• converter• device

Page 8: Chainer の Trainer 解説と NStepLSTM について

Trainer• Trainer フレームワークの⼤元• 渡された Updater、(必要があれば) Evaluator を実⾏する• グラフのダンプ、スナップショット、レポーティング、進捗表

⽰などを、Extension として実⾏できる

© 2017 Retrieva, Inc. 8

Page 9: Chainer の Trainer 解説と NStepLSTM について

Trainer

• 指定した epoch 数になるまで、Updater の update() を呼ぶ

© 2017 Retrieva, Inc. 9

# examples/mnist/train_mnist.pytrainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

# chainer/training/trainer.pyclass Trainer(object):

def run(self):update = self.updater.update# main training looptry:

while not stop_trigger(self):self.observation = {}with reporter.scope(self.observation):

update()

Page 10: Chainer の Trainer 解説と NStepLSTM について

Updater

© 2017 Retrieva, Inc. 10

Trainer

Updater (StandardUpdater)

Iterator

Optimizer

Classifier

model (MLP)

train dataset

EvaluatorIterator

test dataset

Extensions

• dump_graph, snapshot• LogReport, PrintReport• ProgressBar

• converter• loss_func• device

• converter• device

Page 11: Chainer の Trainer 解説と NStepLSTM について

Updater• ⼊⼒を逐次実⾏する• ⼊⼒の Iterator と、Optimizer を持つ• Iterator から⼀つずつデータを読み込み、変換し、Optimizer に

かける

© 2017 Retrieva, Inc. 11

Page 12: Chainer の Trainer 解説と NStepLSTM について

Updater

© 2017 Retrieva, Inc. 12

# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

# chainer/training/updater.pyclass StandardUpdater(Updater):

def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)

optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target

if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)

Page 13: Chainer の Trainer 解説と NStepLSTM について

Updater

© 2017 Retrieva, Inc. 13

# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

# chainer/training/updater.pyclass StandardUpdater(Updater):

def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)

optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target

if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)

Iterator から⼀つ呼び出す

Page 14: Chainer の Trainer 解説と NStepLSTM について

Updater

© 2017 Retrieva, Inc. 14

# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

# chainer/training/updater.pyclass StandardUpdater(Updater):

def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)

optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target

if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)

Iterator から⼀つ呼び出す

Converter にかける(変換し、to_gpu する)

Page 15: Chainer の Trainer 解説と NStepLSTM について

Updater

© 2017 Retrieva, Inc. 15

# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

# chainer/training/updater.pyclass StandardUpdater(Updater):

def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)

optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target

if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)

Iterator から⼀つ呼び出す

Converter にかける(変換し、to_gpu する)

Optimizer のupdate を呼ぶ

Page 16: Chainer の Trainer 解説と NStepLSTM について

Iterator

© 2017 Retrieva, Inc. 16

Trainer

Updater (StandardUpdater)

Iterator

Optimizer

Classifier

model (MLP)

train dataset

EvaluatorIterator

test dataset

Extensions

• dump_graph, snapshot• LogReport, PrintReport• ProgressBar

• converter• loss_func• device

• converter• device

Page 17: Chainer の Trainer 解説と NStepLSTM について

Iterator

© 2017 Retrieva, Inc. 17

# chainer/iterators/serial_iterator.pyclass SerialIterator(iterator.Iterator):

def __next__(self):...return batch

@propertydef epoch_detail(self):

return self.epoch + self.current_position / len(self.dataset)

# examples/mnist/train_mnist.pytrain, test = chainer.datasets.get_mnist()train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

• Iterator として、batch を返す• 回している回数の管理をする

Page 18: Chainer の Trainer 解説と NStepLSTM について

Optimizer

© 2017 Retrieva, Inc. 18

Trainer

Updater (StandardUpdater)

Iterator

Optimizer

Classifier

model (MLP)

train dataset

EvaluatorIterator

test dataset

Extensions

• dump_graph, snapshot• LogReport, PrintReport• ProgressBar

• converter• loss_func• device

• converter• device

Page 19: Chainer の Trainer 解説と NStepLSTM について

Optimizer• ⼊⼒データを model に forward し、返り値の loss を

backward する• 最適化アルゴリズムごとに実装がある• SGD, MomentumSGD, Adam, …• Optimizer で抽象化されている

© 2017 Retrieva, Inc. 19

Page 20: Chainer の Trainer 解説と NStepLSTM について

Optimizer

© 2017 Retrieva, Inc. 20

# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)

# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)

# chainer/optimizer.pyclass GradientMethod(Optimizer):

def update(self, lossfun=None, *args, **kwds):if lossfun is not None:

use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:

self.target.cleargrads()else:

self.target.zerograds()loss.backward()

Page 21: Chainer の Trainer 解説と NStepLSTM について

Optimizer

© 2017 Retrieva, Inc. 21

# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)

# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)

# chainer/optimizer.pyclass GradientMethod(Optimizer):

def update(self, lossfun=None, *args, **kwds):if lossfun is not None:

use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:

self.target.cleargrads()else:

self.target.zerograds()loss.backward()

target は渡された model

(Classifier)

Page 22: Chainer の Trainer 解説と NStepLSTM について

Optimizer

© 2017 Retrieva, Inc. 22

# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)

# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)

# chainer/optimizer.pyclass GradientMethod(Optimizer):

def update(self, lossfun=None, *args, **kwds):if lossfun is not None:

use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:

self.target.cleargrads()else:

self.target.zerograds()loss.backward()

target は渡された model

(Classifier)

model に forward

Page 23: Chainer の Trainer 解説と NStepLSTM について

Optimizer

© 2017 Retrieva, Inc. 23

# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)

# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)

# chainer/optimizer.pyclass GradientMethod(Optimizer):

def update(self, lossfun=None, *args, **kwds):if lossfun is not None:

use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:

self.target.cleargrads()else:

self.target.zerograds()loss.backward()

target は渡された model

(Classifier)

model に forward

backward を実⾏

Page 24: Chainer の Trainer 解説と NStepLSTM について

Classifier

© 2017 Retrieva, Inc. 24

Trainer

Updater (StandardUpdater)

Iterator

Optimizer

Classifier

model (MLP)

train dataset

EvaluatorIterator

test dataset

Extensions

• dump_graph, snapshot• LogReport, PrintReport• ProgressBar

• converter• loss_func• device

• converter• device

Page 25: Chainer の Trainer 解説と NStepLSTM について

Classifier• 教師あり学習⽤の model のラッパー• ⼊⼒と正解データから、loss と accuracy を計算する

© 2017 Retrieva, Inc. 25

Page 26: Chainer の Trainer 解説と NStepLSTM について

Classifier

© 2017 Retrieva, Inc. 26

# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))

# chainer/links/model/classifier.pyclass Classifier(link.Chain):

def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):

def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:

self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)

return self.loss

Page 27: Chainer の Trainer 解説と NStepLSTM について

Classifier

© 2017 Retrieva, Inc. 27

# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))

# chainer/links/model/classifier.pyclass Classifier(link.Chain):

def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):

def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:

self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)

return self.loss

損失関数を指定する

Page 28: Chainer の Trainer 解説と NStepLSTM について

Classifier

© 2017 Retrieva, Inc. 28

# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))

# chainer/links/model/classifier.pyclass Classifier(link.Chain):

def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):

def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:

self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)

return self.loss

損失関数を指定する

model に forward

Page 29: Chainer の Trainer 解説と NStepLSTM について

Classifier

© 2017 Retrieva, Inc. 29

# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))

# chainer/links/model/classifier.pyclass Classifier(link.Chain):

def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):

def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:

self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)

return self.loss

損失関数を指定する

model に forward

loss を算出

accuracy を算出

Page 30: Chainer の Trainer 解説と NStepLSTM について

Classifier

© 2017 Retrieva, Inc. 30

# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))

# chainer/links/model/classifier.pyclass Classifier(link.Chain):

def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):

def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:

self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)

return self.loss

損失関数を指定する

model に forward

loss を算出

accuracy を算出

loss を返す

Page 31: Chainer の Trainer 解説と NStepLSTM について

Evaluator

© 2017 Retrieva, Inc. 31

Trainer

Updater (StandardUpdater)

Iterator

Optimizer

Classifier

model (MLP)

train dataset

EvaluatorIterator

test dataset

Extensions

• dump_graph, snapshot• LogReport, PrintReport• ProgressBar

• converter• loss_func• device

• converter• device

Page 32: Chainer の Trainer 解説と NStepLSTM について

Evaluator• テストデータに対して、loss, accuracy などを計算し、検証す

る• epoch ごとに、現在まで学習された model に対して検証する• ⼤まかには、Updater と対応している

© 2017 Retrieva, Inc. 32

Page 33: Chainer の Trainer 解説と NStepLSTM について

Evaluator

© 2017 Retrieva, Inc. 33

# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,

repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):

def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:

observation = {}with reporter_module.report_scope(observation):

in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):

in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)

eval_func(*in_vars)

Page 34: Chainer の Trainer 解説と NStepLSTM について

Evaluator

© 2017 Retrieva, Inc. 34

# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,

repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):

def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:

observation = {}with reporter_module.report_scope(observation):

in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):

in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)

eval_func(*in_vars)

Iterator から全部呼び出す

Page 35: Chainer の Trainer 解説と NStepLSTM について

Evaluator

© 2017 Retrieva, Inc. 35

# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,

repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):

def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:

observation = {}with reporter_module.report_scope(observation):

in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):

in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)

eval_func(*in_vars)

Iterator から全部呼び出す

Converter にかける(変換し、to_gpu する)

Page 36: Chainer の Trainer 解説と NStepLSTM について

Evaluator

© 2017 Retrieva, Inc. 36

# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,

repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):

def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:

observation = {}with reporter_module.report_scope(observation):

in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):

in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)

eval_func(*in_vars)

Iterator から全部呼び出す

Converter にかける(変換し、to_gpu する)

model に forward

Page 37: Chainer の Trainer 解説と NStepLSTM について

説明していないこと• Reporter 周り• Evaluator で、eval_func しているが戻り値を使っていない理由• (ざっくり⾔うと)Classifier 内で、reporter に loss, accuracy を登

録している

• Extension 周り

© 2017 Retrieva, Inc. 37

Page 38: Chainer の Trainer 解説と NStepLSTM について

第2部 NStrepLSTM との格闘

© 2017 Retrieva, Inc. 38

Page 39: Chainer の Trainer 解説と NStepLSTM について

NStepLSTM とは• RNN のための、Chainer 1.16.0 で導⼊された Link• cuDNN の恩恵を受けて、⾼速に動く• 既存の LSTM と使い⽅が違う• 既存の LSTM のサンプルは examples/ptb/train_ptb.py

© 2017 Retrieva, Inc. 39

Page 40: Chainer の Trainer 解説と NStepLSTM について

RNN• Recurrent Neural Network• 並び⽅に意味のある、「系列データ」を扱う場合に⽤いられる• 応⽤例:⽂章の推定、⾳声認識、変動する数値の推定• 例:⽂章が途中まで与えられた時、次の単語を予測する問題

© 2017 Retrieva, Inc. 40

私 は ⽩い ⽝ が ?x1 x2 x3 x4 x5

y1 y2 y3 y4 y5

x1〜x5を⼊⼒データとして、y5を推定する。

Page 41: Chainer の Trainer 解説と NStepLSTM について

LSTM と NStepLSTMデータ1 1 2データ1ラベル A Bデータ2 1 2 3データ2ラベル A B C

© 2017 Retrieva, Inc. 41

• LSTM(逐次渡す)• x1: Variable[1, 1]• t1: Variable[B, B]• x2: Variable[2, 2]• t2: Variable[0, C]

• NStepLSTM(⼀度に渡す)• xs: [Variable[1,2], Variable[1,2,3]]• ts: [Variable[A,B], Variable[A,B,C]]

Page 42: Chainer の Trainer 解説と NStepLSTM について

LSTM と NStepLSTMデータ1 1 2データ1ラベル A Bデータ2 1 2 3データ2ラベル A B C

© 2017 Retrieva, Inc. 42

• LSTM(逐次渡す)• x1: Variable[1, 1]• t1: Variable[B, B]• x2: Variable[2, 2]• t2: Variable[0, C]

• NStepLSTM(⼀度に渡す)• xs: [Variable[1,2], Variable[1,2,3]]• ts: [Variable[A,B], Variable[A,B,C]]

⻑さが合っていない時、0 などで埋める必要がある

⻑さを合わせなくて良い

Variable の list を渡す

Page 43: Chainer の Trainer 解説と NStepLSTM について

NStepLSTM サンプル

© 2017 Retrieva, Inc. 43

class RNNNStepLSTM(chainer.Chain):def __init__(self, n_layer, n_units, train=True):

super(RNNNStepLSTM, self).__init__(l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),

)self.n_layer = n_layerself.n_units = n_units

def __call__(self, xs):xp = self.xphx = chainer.Variable(xp.zeros(

(self.n_layer, len(xs), self.n_units), dtype=xp.float32))cx = chainer.Variable(xp.zeros(

(self.n_layer, len(xs), self.n_units), dtype=xp.float32))hy, cy, ys = self.l1(hx, cx, xs, train=self.train)

Page 44: Chainer の Trainer 解説と NStepLSTM について

NStepLSTM サンプル

© 2017 Retrieva, Inc. 44

class RNNNStepLSTM(chainer.Chain):def __init__(self, n_layer, n_units, train=True):

super(RNNNStepLSTM, self).__init__(l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),

)self.n_layer = n_layerself.n_units = n_units

def __call__(self, xs):xp = self.xphx = chainer.Variable(xp.zeros(

(self.n_layer, len(xs), self.n_units), dtype=xp.float32))cx = chainer.Variable(xp.zeros(

(self.n_layer, len(xs), self.n_units), dtype=xp.float32))hy, cy, ys = self.l1(hx, cx, xs, train=self.train)

レイヤー数、ユニット数、Dropout を指定するパラメータの

初期状態を作成し、渡す

Variable のリストを⼊⼒する

出⼒も Variable のリスト

Page 45: Chainer の Trainer 解説と NStepLSTM について

NStepLSTM と、標準的な Trainer の齟齬• 標準的な Trainer の構成では、model には Variable を渡す• NStepLSTM では、「Variable のリスト」を渡さなければいけない

© 2017 Retrieva, Inc. 45

# chainer/training/updater.pyclass StandardUpdater(Updater):

def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)...if isinstance(in_arrays, tuple):

in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)

Variableを作成し、そのまま渡している

Page 46: Chainer の Trainer 解説と NStepLSTM について

NStepLSTM と、標準的な Trainer の齟齬• loss の計算を、既存のメソッドに任せられない

© 2017 Retrieva, Inc. 46

# chainer/links/model/classifier.pyclass Classifier(link.Chain):

def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):

def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:

self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)

return self.loss

NStepLSTM の出⼒だと、y は Variable のリスト

accuracy も同様

Page 47: Chainer の Trainer 解説と NStepLSTM について

ptb/train_ptb.py を NStepLSTM で• https://github.com/kei-s/chainer-ptb-

nsteplstm/blob/master/train_ptb_nstep.py• できるだけ構成を同じにして(Trainer の上で)、

train_ptb.py を NStepLSTM を使って実装してみる

• Disclaimer• testモード(100件)では完⾛したけど、全データは⾛らせていない• 無駄なコードはありそう…

© 2017 Retrieva, Inc. 47

Page 48: Chainer の Trainer 解説と NStepLSTM について

ptb/train_ptb.py を NStepLSTM で• モデル

• EmbedID にかけるため、⼀度 concat し、split_axis でまた分ける• 系列のそれぞれの要素を Linear にかける

• Iterator• bprop_len を Iterator に渡し、バッチで系列化したものを返す

• Converter• NStepLSTM を使った seq2seq https://github.com/pfnet/chainer/pull/2070 を参照

• Updater• 系列をそのままモデルに渡すように変更

• Evaluator• 系列をそのままモデルに渡すように変更

• Lossfun• softmax_cross_entropy をそれぞれの系列に対してかけ、⾜し合わせる

© 2017 Retrieva, Inc. 48

Page 49: Chainer の Trainer 解説と NStepLSTM について

ptb/train_ptb.py を NStepLSTM で• ⾼速化したか?• Estimated time では 6時間 → 3時間• とはいえ、実際に⾛らせると 3時間以上かかりそう

• Estimated Time に Evaluator 部分が考慮されてなさそう

• ⾼速化のために必要なこと• Chainer もしくは Numpy の世界で処理を終わらせること• Python の世界でループを回すとかなり遅くなる

• Lossfun でループを回しているが、concat して渡してもよさそう(互換性が微妙な予感)

© 2017 Retrieva, Inc. 49

Page 50: Chainer の Trainer 解説と NStepLSTM について

© 2017 Retrieva, Inc.