chainer の trainer 解説と nsteplstm について
TRANSCRIPT
Chainer の Trainer 解説とNStepLSTM について株式会社レトリバ
© 2017 Retrieva, Inc.
⾃⼰紹介• ⽩⼟慧(シラツチ ケイ)• 株式会社レトリバ• 2016年4⽉⼊社
• Ruby on Rails / JavaScript• フロントエンド側の⼈間
• ⼤学時代は複雑ネットワーク科学の研究• Chainer ⼊⾨中
© 2017 Retrieva, Inc. 2
アジェンダ• 第1部 Chainer における Trainer の解説• 第2部 NStepLSTM との格闘
© 2017 Retrieva, Inc. 3
アンケート• Chainer を使っている⽅• Chainer の Trainer を使っている⽅• LSTM を使っている⽅• NStepLSTM を使っている⽅• NStepLSTM と Trainer を使っている⽅
© 2017 Retrieva, Inc. 4
第1部 Chainer における Trainer
© 2017 Retrieva, Inc. 5
Chainer における Trainer• Chainer 1.11.0 から導⼊された学習フレームワーク• batchの取り出し、forward/backward が抽象化されている• 進捗表⽰、モデルのスナップショットなど
• Trainer 後から⼊⾨した⼈(私も)は、MNIST のサンプルがTrainerで抽象化されていて、何が起きているのかわからない• 以前から Chainer を使っている⼈は、Trainer なしで動かして
いることが多い
© 2017 Retrieva, Inc. 6
Trainer 全体図 (train_mnist.py)
© 2017 Retrieva, Inc. 7
Trainer
Updater (StandardUpdater)
Iterator
Optimizer
Classifier
model (MLP)
train dataset
EvaluatorIterator
test dataset
Extensions
• dump_graph, snapshot• LogReport, PrintReport• ProgressBar
• converter• loss_func• device
• converter• device
Trainer• Trainer フレームワークの⼤元• 渡された Updater、(必要があれば) Evaluator を実⾏する• グラフのダンプ、スナップショット、レポーティング、進捗表
⽰などを、Extension として実⾏できる
© 2017 Retrieva, Inc. 8
Trainer
• 指定した epoch 数になるまで、Updater の update() を呼ぶ
© 2017 Retrieva, Inc. 9
# examples/mnist/train_mnist.pytrainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
# chainer/training/trainer.pyclass Trainer(object):
def run(self):update = self.updater.update# main training looptry:
while not stop_trigger(self):self.observation = {}with reporter.scope(self.observation):
update()
Updater
© 2017 Retrieva, Inc. 10
Trainer
Updater (StandardUpdater)
Iterator
Optimizer
Classifier
model (MLP)
train dataset
EvaluatorIterator
test dataset
Extensions
• dump_graph, snapshot• LogReport, PrintReport• ProgressBar
• converter• loss_func• device
• converter• device
Updater• ⼊⼒を逐次実⾏する• ⼊⼒の Iterator と、Optimizer を持つ• Iterator から⼀つずつデータを読み込み、変換し、Optimizer に
かける
© 2017 Retrieva, Inc. 11
Updater
© 2017 Retrieva, Inc. 12
# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
# chainer/training/updater.pyclass StandardUpdater(Updater):
def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)
optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target
if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)
Updater
© 2017 Retrieva, Inc. 13
# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
# chainer/training/updater.pyclass StandardUpdater(Updater):
def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)
optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target
if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)
Iterator から⼀つ呼び出す
Updater
© 2017 Retrieva, Inc. 14
# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
# chainer/training/updater.pyclass StandardUpdater(Updater):
def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)
optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target
if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)
Iterator から⼀つ呼び出す
Converter にかける(変換し、to_gpu する)
Updater
© 2017 Retrieva, Inc. 15
# examples/mnist/train_mnist.pyupdater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
# chainer/training/updater.pyclass StandardUpdater(Updater):
def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)
optimizer = self._optimizers['main']loss_func = self.loss_func or optimizer.target
if isinstance(in_arrays, tuple):in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)
Iterator から⼀つ呼び出す
Converter にかける(変換し、to_gpu する)
Optimizer のupdate を呼ぶ
Iterator
© 2017 Retrieva, Inc. 16
Trainer
Updater (StandardUpdater)
Iterator
Optimizer
Classifier
model (MLP)
train dataset
EvaluatorIterator
test dataset
Extensions
• dump_graph, snapshot• LogReport, PrintReport• ProgressBar
• converter• loss_func• device
• converter• device
Iterator
© 2017 Retrieva, Inc. 17
# chainer/iterators/serial_iterator.pyclass SerialIterator(iterator.Iterator):
def __next__(self):...return batch
@propertydef epoch_detail(self):
return self.epoch + self.current_position / len(self.dataset)
# examples/mnist/train_mnist.pytrain, test = chainer.datasets.get_mnist()train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
• Iterator として、batch を返す• 回している回数の管理をする
Optimizer
© 2017 Retrieva, Inc. 18
Trainer
Updater (StandardUpdater)
Iterator
Optimizer
Classifier
model (MLP)
train dataset
EvaluatorIterator
test dataset
Extensions
• dump_graph, snapshot• LogReport, PrintReport• ProgressBar
• converter• loss_func• device
• converter• device
Optimizer• ⼊⼒データを model に forward し、返り値の loss を
backward する• 最適化アルゴリズムごとに実装がある• SGD, MomentumSGD, Adam, …• Optimizer で抽象化されている
© 2017 Retrieva, Inc. 19
Optimizer
© 2017 Retrieva, Inc. 20
# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)
# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)
# chainer/optimizer.pyclass GradientMethod(Optimizer):
def update(self, lossfun=None, *args, **kwds):if lossfun is not None:
use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:
self.target.cleargrads()else:
self.target.zerograds()loss.backward()
Optimizer
© 2017 Retrieva, Inc. 21
# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)
# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)
# chainer/optimizer.pyclass GradientMethod(Optimizer):
def update(self, lossfun=None, *args, **kwds):if lossfun is not None:
use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:
self.target.cleargrads()else:
self.target.zerograds()loss.backward()
target は渡された model
(Classifier)
Optimizer
© 2017 Retrieva, Inc. 22
# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)
# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)
# chainer/optimizer.pyclass GradientMethod(Optimizer):
def update(self, lossfun=None, *args, **kwds):if lossfun is not None:
use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:
self.target.cleargrads()else:
self.target.zerograds()loss.backward()
target は渡された model
(Classifier)
model に forward
Optimizer
© 2017 Retrieva, Inc. 23
# chainer/training/updater.pyloss_func = self.loss_func or optimizer.targetoptimizer.update(loss_func, *in_vars)
# examples/mnist/train_mnist.pyoptimizer = chainer.optimizers.Adam()optimizer.setup(model)
# chainer/optimizer.pyclass GradientMethod(Optimizer):
def update(self, lossfun=None, *args, **kwds):if lossfun is not None:
use_cleargrads = getattr(self, '_use_cleargrads', False)loss = lossfun(*args, **kwds)if use_cleargrads:
self.target.cleargrads()else:
self.target.zerograds()loss.backward()
target は渡された model
(Classifier)
model に forward
backward を実⾏
Classifier
© 2017 Retrieva, Inc. 24
Trainer
Updater (StandardUpdater)
Iterator
Optimizer
Classifier
model (MLP)
train dataset
EvaluatorIterator
test dataset
Extensions
• dump_graph, snapshot• LogReport, PrintReport• ProgressBar
• converter• loss_func• device
• converter• device
Classifier• 教師あり学習⽤の model のラッパー• ⼊⼒と正解データから、loss と accuracy を計算する
© 2017 Retrieva, Inc. 25
Classifier
© 2017 Retrieva, Inc. 26
# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))
# chainer/links/model/classifier.pyclass Classifier(link.Chain):
def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):
def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)
return self.loss
Classifier
© 2017 Retrieva, Inc. 27
# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))
# chainer/links/model/classifier.pyclass Classifier(link.Chain):
def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):
def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)
return self.loss
損失関数を指定する
Classifier
© 2017 Retrieva, Inc. 28
# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))
# chainer/links/model/classifier.pyclass Classifier(link.Chain):
def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):
def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)
return self.loss
損失関数を指定する
model に forward
Classifier
© 2017 Retrieva, Inc. 29
# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))
# chainer/links/model/classifier.pyclass Classifier(link.Chain):
def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):
def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)
return self.loss
損失関数を指定する
model に forward
loss を算出
accuracy を算出
Classifier
© 2017 Retrieva, Inc. 30
# examples/mnist/train_mnist.pymodel = L.Classifier(MLP(args.unit, 10))
# chainer/links/model/classifier.pyclass Classifier(link.Chain):
def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):
def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)
return self.loss
損失関数を指定する
model に forward
loss を算出
accuracy を算出
loss を返す
Evaluator
© 2017 Retrieva, Inc. 31
Trainer
Updater (StandardUpdater)
Iterator
Optimizer
Classifier
model (MLP)
train dataset
EvaluatorIterator
test dataset
Extensions
• dump_graph, snapshot• LogReport, PrintReport• ProgressBar
• converter• loss_func• device
• converter• device
Evaluator• テストデータに対して、loss, accuracy などを計算し、検証す
る• epoch ごとに、現在まで学習された model に対して検証する• ⼤まかには、Updater と対応している
© 2017 Retrieva, Inc. 32
Evaluator
© 2017 Retrieva, Inc. 33
# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,
repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):
def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:
observation = {}with reporter_module.report_scope(observation):
in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):
in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)
eval_func(*in_vars)
Evaluator
© 2017 Retrieva, Inc. 34
# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,
repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):
def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:
observation = {}with reporter_module.report_scope(observation):
in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):
in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)
eval_func(*in_vars)
Iterator から全部呼び出す
Evaluator
© 2017 Retrieva, Inc. 35
# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,
repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):
def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:
observation = {}with reporter_module.report_scope(observation):
in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):
in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)
eval_func(*in_vars)
Iterator から全部呼び出す
Converter にかける(変換し、to_gpu する)
Evaluator
© 2017 Retrieva, Inc. 36
# examples/mnist/train_mnist.pytest_iter = chainer.iterators.SerialIterator(test, args.batchsize,
repeat=False, shuffle=False)trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
# chainer/training/extensions/evaluator.pyclass Evaluator(extension.Extension):
def evaluate(self):iterator = self._iterators['main']target = self._targets['main']eval_func = self.eval_func or targetit = copy.copy(iterator)for batch in it:
observation = {}with reporter_module.report_scope(observation):
in_arrays = self.converter(batch, self.device)if isinstance(in_arrays, tuple):
in_vars = tuple(variable.Variable(x, volatile='on')for x in in_arrays)
eval_func(*in_vars)
Iterator から全部呼び出す
Converter にかける(変換し、to_gpu する)
model に forward
説明していないこと• Reporter 周り• Evaluator で、eval_func しているが戻り値を使っていない理由• (ざっくり⾔うと)Classifier 内で、reporter に loss, accuracy を登
録している
• Extension 周り
© 2017 Retrieva, Inc. 37
第2部 NStrepLSTM との格闘
© 2017 Retrieva, Inc. 38
NStepLSTM とは• RNN のための、Chainer 1.16.0 で導⼊された Link• cuDNN の恩恵を受けて、⾼速に動く• 既存の LSTM と使い⽅が違う• 既存の LSTM のサンプルは examples/ptb/train_ptb.py
© 2017 Retrieva, Inc. 39
RNN• Recurrent Neural Network• 並び⽅に意味のある、「系列データ」を扱う場合に⽤いられる• 応⽤例:⽂章の推定、⾳声認識、変動する数値の推定• 例:⽂章が途中まで与えられた時、次の単語を予測する問題
© 2017 Retrieva, Inc. 40
私 は ⽩い ⽝ が ?x1 x2 x3 x4 x5
y1 y2 y3 y4 y5
x1〜x5を⼊⼒データとして、y5を推定する。
LSTM と NStepLSTMデータ1 1 2データ1ラベル A Bデータ2 1 2 3データ2ラベル A B C
© 2017 Retrieva, Inc. 41
• LSTM(逐次渡す)• x1: Variable[1, 1]• t1: Variable[B, B]• x2: Variable[2, 2]• t2: Variable[0, C]
• NStepLSTM(⼀度に渡す)• xs: [Variable[1,2], Variable[1,2,3]]• ts: [Variable[A,B], Variable[A,B,C]]
LSTM と NStepLSTMデータ1 1 2データ1ラベル A Bデータ2 1 2 3データ2ラベル A B C
© 2017 Retrieva, Inc. 42
• LSTM(逐次渡す)• x1: Variable[1, 1]• t1: Variable[B, B]• x2: Variable[2, 2]• t2: Variable[0, C]
• NStepLSTM(⼀度に渡す)• xs: [Variable[1,2], Variable[1,2,3]]• ts: [Variable[A,B], Variable[A,B,C]]
⻑さが合っていない時、0 などで埋める必要がある
⻑さを合わせなくて良い
Variable の list を渡す
NStepLSTM サンプル
© 2017 Retrieva, Inc. 43
class RNNNStepLSTM(chainer.Chain):def __init__(self, n_layer, n_units, train=True):
super(RNNNStepLSTM, self).__init__(l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),
)self.n_layer = n_layerself.n_units = n_units
def __call__(self, xs):xp = self.xphx = chainer.Variable(xp.zeros(
(self.n_layer, len(xs), self.n_units), dtype=xp.float32))cx = chainer.Variable(xp.zeros(
(self.n_layer, len(xs), self.n_units), dtype=xp.float32))hy, cy, ys = self.l1(hx, cx, xs, train=self.train)
NStepLSTM サンプル
© 2017 Retrieva, Inc. 44
class RNNNStepLSTM(chainer.Chain):def __init__(self, n_layer, n_units, train=True):
super(RNNNStepLSTM, self).__init__(l1 = L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),
)self.n_layer = n_layerself.n_units = n_units
def __call__(self, xs):xp = self.xphx = chainer.Variable(xp.zeros(
(self.n_layer, len(xs), self.n_units), dtype=xp.float32))cx = chainer.Variable(xp.zeros(
(self.n_layer, len(xs), self.n_units), dtype=xp.float32))hy, cy, ys = self.l1(hx, cx, xs, train=self.train)
レイヤー数、ユニット数、Dropout を指定するパラメータの
初期状態を作成し、渡す
Variable のリストを⼊⼒する
出⼒も Variable のリスト
NStepLSTM と、標準的な Trainer の齟齬• 標準的な Trainer の構成では、model には Variable を渡す• NStepLSTM では、「Variable のリスト」を渡さなければいけない
© 2017 Retrieva, Inc. 45
# chainer/training/updater.pyclass StandardUpdater(Updater):
def update_core(self):batch = self._iterators['main'].next()in_arrays = self.converter(batch, self.device)...if isinstance(in_arrays, tuple):
in_vars = tuple(variable.Variable(x) for x in in_arrays)optimizer.update(loss_func, *in_vars)
Variableを作成し、そのまま渡している
NStepLSTM と、標準的な Trainer の齟齬• loss の計算を、既存のメソッドに任せられない
© 2017 Retrieva, Inc. 46
# chainer/links/model/classifier.pyclass Classifier(link.Chain):
def __init__(self, predictor,lossfun=softmax_cross_entropy.softmax_cross_entropy,accfun=accuracy.accuracy):
def __call__(self, *args):self.y = self.predictor(*x)self.loss = self.lossfun(self.y, t)reporter.report({'loss': self.loss}, self)if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)reporter.report({'accuracy': self.accuracy}, self)
return self.loss
NStepLSTM の出⼒だと、y は Variable のリスト
accuracy も同様
ptb/train_ptb.py を NStepLSTM で• https://github.com/kei-s/chainer-ptb-
nsteplstm/blob/master/train_ptb_nstep.py• できるだけ構成を同じにして(Trainer の上で)、
train_ptb.py を NStepLSTM を使って実装してみる
• Disclaimer• testモード(100件)では完⾛したけど、全データは⾛らせていない• 無駄なコードはありそう…
© 2017 Retrieva, Inc. 47
ptb/train_ptb.py を NStepLSTM で• モデル
• EmbedID にかけるため、⼀度 concat し、split_axis でまた分ける• 系列のそれぞれの要素を Linear にかける
• Iterator• bprop_len を Iterator に渡し、バッチで系列化したものを返す
• Converter• NStepLSTM を使った seq2seq https://github.com/pfnet/chainer/pull/2070 を参照
• Updater• 系列をそのままモデルに渡すように変更
• Evaluator• 系列をそのままモデルに渡すように変更
• Lossfun• softmax_cross_entropy をそれぞれの系列に対してかけ、⾜し合わせる
© 2017 Retrieva, Inc. 48
ptb/train_ptb.py を NStepLSTM で• ⾼速化したか?• Estimated time では 6時間 → 3時間• とはいえ、実際に⾛らせると 3時間以上かかりそう
• Estimated Time に Evaluator 部分が考慮されてなさそう
• ⾼速化のために必要なこと• Chainer もしくは Numpy の世界で処理を終わらせること• Python の世界でループを回すとかなり遅くなる
• Lossfun でループを回しているが、concat して渡してもよさそう(互換性が微妙な予感)
© 2017 Retrieva, Inc. 49
© 2017 Retrieva, Inc.