Tensorflow Estimator API
Databricks 無料トライアル
Tensorflow Estimator API とは
Estimator は、完全なモデルを表しますが、ユーザーの多くに複雑な印象を与える傾向があります。Estimator API とは、モデルを訓練して、その精度を評価し、推論を作成するためのメソッドを提供する高レベル API です。下の図のように、TensorFlow は複数の API 層からなるプログラミングスタックを提供します。Estimator には、事 前構築された Estimator と、独自でカスタマイズする Estimator の 2 つのタイプがあります。Estimator ベースのモデルは、モデルを変更せずに、ローカルホストまたは分散マルチサーバー環境で実行できます。また、モデルを再コーディングせずに、Estimator ベースのモデルを CPU、GPU または TPU で実行することも可能です。
Estimator の 4 つの主要な機能
- 訓練:与えられた入力データをもとに、指定された回数だけモデルを訓練
- 評価:テストデータセットに基づいたモデルの評価
- 推論:訓練されたモデルを使用したインターフェースの実行
- エクスポート:サーブするためのモデルのエクスポート
さらに、Estimator には、チェックポイントの保存や復元、サマリーの作成などの訓練ジョブに共通するデフォルト動作があります。Estimator では、TensorFlow グラフのモデルと入力部分に対応するモデル関数(model_fn)と入力関数(input_fn)を記述する必要があります。
Estimator を使用するメリット
- モデル開発者間での、実装共有を簡素化
- モデルの作成が必要な場合、低レベルの TensorFlow API と比べ、操作が容易であることから、高レベルの直感的なコードで優れたモデルの開発が可能
- tf.keras.layers 上に構築されているため、カスタマイズが簡単
- Estimator でのグラフ作成による作業負担の軽減
- 安全な分散型トレーニングループを提供し、以下の方法とタイミングを管理:
- グラ フの作成
- 変数の初期化
- データの読み込み
- 例外処理への対応
- チェックポイントファイルの作成および障害からの復旧
- TensorBoard サマリーの保存