私はこれとまったく同じ問題に遭遇しました、そして人はそれがウサギの穴でした。誰かが1日の作業を節約できる可能性があるため、ここに私のソリューションを投稿したかった:
TensorFlowスレッド固有のデータ構造
TensorFlowには、model.predict
を呼び出すときに舞台裏で機能する2つの主要なデータ構造があります。 (またはkeras.models.load_model
、またはkeras.backend.clear_session
、またはTensorFlowバックエンドと相互作用する他のほとんどすべての関数):
- Kerasモデルの構造を表すTensorFlowグラフ
- 現在のグラフとTensorFlowランタイム間の接続であるTensorFlowセッション
いくつかの掘り下げなしにドキュメントで明確に明確にされていないことは、セッションとグラフの両方が現在のスレッドのプロパティであるということです。 。こことここのAPIドキュメントを参照してください。
異なるスレッドでのTensorFlowモデルの使用
モデルを一度ロードしてから.predict()
を呼び出すのは自然なことです。 後で何度もそれについて:
from keras.models import load_model
MY_MODEL = load_model('path/to/model/file')
def some_worker_function(inputs):
return MY_MODEL.predict(inputs)
CeleryのようなWebサーバーまたはワーカープールのコンテキストでは、これは、load_model
を含むモジュールをインポートするときにモデルをロードすることを意味します。 行の場合、別のスレッドがsome_worker_function
を実行します 、Kerasモデルを含むグローバル変数でpredictを実行します。ただし、別のスレッドにロードされたモデルでpredictを実行しようとすると、「テンソルはこのグラフの要素ではありません」というエラーが発生します。 ValueErrorなど、このトピックに触れたいくつかのSO投稿のおかげで、Tensor Tensor(...)はこのグラフの要素ではありません。グローバル変数kerasモデルを使用する場合。これを機能させるには、使用されたTensorFlowグラフを使用する必要があります。前に見たように、グラフは現在のスレッドのプロパティです。更新されたコードは次のようになります:
from keras.models import load_model
import tensorflow as tf
MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()
def some_worker_function(inputs):
with MY_GRAPH.as_default():
return MY_MODEL.predict(inputs)
ここでやや意外なひねりは次のとおりです。Thread
を使用している場合は、上記のコードで十分です。 sですが、Process
を使用している場合は無期限にハングします es。 また、デフォルトでは、Celeryはプロセスを使用してすべてのワーカープールを管理します。したがって、この時点では、物事はまだ セロリに取り組んでいません。
これがThread
でのみ機能するのはなぜですか s?
Pythonでは、Thread
■親プロセスと同じグローバル実行コンテキストを共有します。 Python _threadドキュメントから:
このモジュールは、複数のスレッド(軽量プロセスまたはタスクとも呼ばれます)(グローバルデータスペースを共有する複数の制御スレッド)を操作するための低レベルのプリミティブを提供します。
スレッドは実際の個別のプロセスではないため、同じPythonインタープリターを使用し、悪名高いGlobal Interpeter Lock(GIL)の対象となります。おそらくこの調査にとってもっと重要なのは、彼らが共有することです。 親とのグローバルデータスペース。
これとは対照的に、Process
esは実際です プログラムによって生成された新しいプロセス。これは次のことを意味します:
- 新しいPythonインタープリターインスタンス(GILなし)
- グローバルアドレス空間は重複しています
ここで違いに注意してください。 Thread
の間 ■共有された単一のグローバルセッション変数(tensorflow_backend
に内部的に保存されている)にアクセスできます Kerasのモジュール)、Process
esにはSession変数の重複があります。
この問題を最もよく理解しているのは、Session変数がクライアント(プロセス)とTensorFlowランタイム間の一意の接続を表すことになっているが、フォークプロセスで複製されるため、この接続情報が適切に調整されないことです。これにより、別のプロセスで作成されたセッションを使用しようとすると、TensorFlowがハングします。 TensorFlowの内部でこれがどのように機能しているかについて誰かがもっと洞察を持っているなら、私はそれを聞いてみたいです!
解決策/回避策
Thread
を使用するようにCeleryを調整しました s Process
の代わりに プーリング用のes。このアプローチにはいくつかの欠点がありますが(上記のGILコメントを参照)、これによりモデルを1回だけロードできます。 TensorFlowランタイムはすべてのCPUコアを最大化するため、実際にはCPUにバインドされていません(Pythonで記述されていないためGILを回避できます)。スレッドベースのプーリングを行うには、Celeryに別のライブラリを提供する必要があります。ドキュメントは2つのオプションを提案しています:gevent
またはeventlet
。次に、選択したライブラリを--pool
を介してワーカーに渡します。 コマンドライン引数。
あるいは、(すでに@ pX0rを知っているように)Theanoなどの他のKerasバックエンドにはこの問題がないようです。これらの問題はTensorFlowの実装の詳細と密接に関連しているため、これは理にかなっています。個人的にはまだTheanoを試したことがないので、マイレージは異なる場合があります。
この質問が少し前に投稿されたことは知っていますが、問題はまだ残っているので、これが誰かに役立つことを願っています!