メインコンテンツへジャンプ

Original Blog : Synthetic Data for Better Machine Learning

翻訳: junichi.maruyama 

 

この1年で最も話題になった、ChatGPTDALL-Eのような生成AIの進化を試したことがある人も多いでしょう。これらのツールは、複雑なデータを消費し、より多くのデータを生成することで、驚くほど知的なもののように感じられるのです。これらやその他の新しいアイデア(diffusion modelsgenerative adversarial networks、GAN)は、遊んでみると楽しく、恐ろしいとさえ感じます。

しかし、日常的な機械学習のタスクは、表形式のデータと「通常の」データサイエンス・ツールを使って、売上予測や顧客離れを予測することなどであり、ボッシュが火星の生物をどう描いたかを想像することではない。

A still life on Mars in the style of Hieronymus Bosch, from DALL-E 2
ヒエロニムス・ボス風の火星の生物画(『DALL-E 2』より

生成AIが、例えば単純な回帰の問題に役立つとしたらどうだろう。あなたが持っている実際のビジネスデータのような合成データを生成することができるアイデアの関連クラスがあります。合成データは、広く考えられているジェネレーティブAIの重要なアプリケーションです。

このブログでは、典型的な機械学習プロセスにおける合成データのいくつかの使用法を検討します。その回帰問題をどのように支援できるのか、あるいは機密データの取り扱いに関する運用上の懸念に役立つのか。合成データのモデリングにはオープンソースのライブラリSDV (Synthetic Data Vault) を使用し、合成データの生成にはMLflow, Apache Spark , Deltaを使用して管理し、最後にDatabricks Auto MLで回帰問題にどのように影響するかを探ります。

なぜ機械学習に合成データなのか?

作りかけのデータは、現実世界を学ぶのに何の役に立つのか?適当に作ったデータは役に立ちません。しかし、現実のデータによく似たデータであれば、役に立つかもしれません。

まず、誰もがより多くのデータを求めています。それは、より良い機械学習モデルを意味する(こともあるからです)。機械学習は現実の世界をモデル化するので、より多くのデータがあれば、その世界の全体像を把握することができます。角のあるケースで何が起こるのか、何が異常なのか、何が繰り返し観察されるのか。本物のデータは入手困難ですが、本物そっくりのデータを無限に入手するのは簡単です。

しかし、合成データは、実際に存在するデータを模倣することしかできない。合成データは、実際のデータを模倣するだけであり、実際のデータセットにはない新たな機微を明らかにすることはできない。しかし、実際のデータが意味することを外挿することは可能であり、場合によっては有益であることもあります。

第二に、データは自由に共有できない場合があります。個人を特定できる機密情報(PII)が含まれている可能性があります。新しいチームとデータを共有し、探査や分析作業を迅速に行うことが望ましいかもしれませんが、共有するためには、長時間の再編集、特別な取り扱い、フォームへの記入など、官僚的な作業が必要になる可能性があります。

合成データは、機密データのようでありながら、実際のデータではないデータを共有することで、その中間的な役割を果たします。場合によっては、この方法にも問題があるかもしれません。合成データが実際のデータポイントに少し似ているとしたら、どうでしょうか。また、不十分な場合もあります。

しかし、合成データを共有することで十分なデータセキュリティを確保しつつ、コラボレーションを加速させることができるユースケースはたくさんある。例えば、新しい問題を解決する信頼性の高い機械学習パイプラインを開発するために、請負業者のチームに協力してもらいたいと考えています。しかし、機密性の高いデータセットを彼らと共有することはできません。合成データを共有すれば、彼らが実際のデータで実行してもうまく機能するパイプラインを構築するのに十分すぎるほど役立つでしょう。

問題:ビッグティッパー

このブログでは、よく知られているNYCタクシーのデータセットを使って説明します。Databricksでは、/databricks-datasets/nyctaxi/tables/nyctaxi_yellowで利用可能です。これは、10年以上にわたってニューヨークでタクシーに乗ったときの基本的な情報を記録したもので、ピックアップとドロップオフポイント、距離、料金、通行料、チップなどが含まれています。何十億行もある大きなもので、このサンプルはこのように始まるサンプルで動作します:

Big Tippers

ここでは、ライダーが旅行の終わりに追加するチップを予測することが問題になります。タクシー内の支払いシステムは、チップの額を機転を利かせて提案したいのかもしれない。その場合、高すぎる、あるいは低すぎる金額を提案しないようにするのが得策である。

これはよくある回帰問題である。しかし、さまざまな理由から、このデータが機密であると考えられているとしよう。請負業者やデータサイエンス・チームと共有できればいいのですが、そうすると、あらゆる法的手続きを踏むことになります。このデータを共有せずに、どうやって正確なモデルを作ることができるでしょうか?

生データを共有するのではなく、その合成版を共有することを試してみてください。

数分で合成データ

SDV は、データを合成するためのPythonライブラリです。テーブル内のデータ、複数のリレーショナルテーブルのデータ、時系列データを模倣することができます。variational autoencoders(VAE)、generative adversarial networks (GAN)、copulasなど、データをモデリングするためのアプローチもサポートしています。SDVは、生成されたデータの制約を強制したり、個人情報を再編集したりすることができます。実際、モデリングに必要なのは、簡単モードのTabularPresetクラスを使用した、このスニペットだけである:

metadata = Metadata()
metadata.add_table(name="nyctaxi_yellow", data=table_nyctaxi)

model = TabularPreset(name='FAST_ML', metadata=metadata.get_table_meta("nyctaxi_yellow"))
model.fit(table_nyctaxi)

model.sample(num_rows=5, randomize_samples=False)

Synthetic Data in Minutes

At a glance, it sure looks plausible! Also included are data quality reports, which give some sense of how well the model believes its results match original data:

Overall Quality Score: 75.18%

Properties:
Column Shapes: 66.88%
Column Pair Trends: 83.47%

Data Quality

これらのプロットは、各列の合成データの分布がオリジナルとどの程度一致しているか、また合成データと実データの相関関係を示しています。これらを0~100%の間で点数化し、全体として75%を与えています。これは「OK」です。(SDMetricsライブラリでもう少し詳しく説明しています)この時点では、なぜstore_and_fwd_flag列が他の列よりも忠実度が低いのかは不明です。

合成データの品質を評価する

その合成データを詳しく見てみると(おそらくDatabricksのData Visualizationタブを使用!)、問題があることがわかります:

  • MTAの税金やチップなど、金額がマイナスになっているものがある。
  • 旅客数と距離が0になることも
  • 距離は直線距離よりも時折、ありえないほど短くなることがある
  • 経度・緯度がニューヨークのどこにもない(あるいは緯度90度以上など全く無効な)場合がある
  • 金銭の額が小数点以下2桁以上ある
  • ピックアップの時間がドロップオフの時間より後になることもあれば、12時間以上の長時間のシフトになることもある

実際、これらの問題の多くは、元のデータセットに見いだされています。機械学習モデルと同じで、ゴミが入ればゴミが出る。明らかに問題があるデータをエミュレートしようとするよりも、元データの問題を修正する方が価値があります。簡単のために、明らかに悪いデータの行は、他の行と同様に削除することができます:

  • 金額がマイナスである
  • ドロップオフがピックアップの前にある、またはピックアップの後に不当に長い時間がかかる。
  • ロケ地はニューヨークの近郊
  • 距離が正しくない、または不当に大きい
  • 始点と終点で距離がありえないほど短くなる

本題に入りますが、改善されたフィルター付きのデータセットでやり直すと、品質スコアは82%になります。しかし、品質を向上させるためには、ソースデータを修正する以外にもやるべきことがあります。

制約条件の使用

上記は、実データと合成データが満たすべきいくつかの条件です。データを生成するモデルは、本来、生成する値を意味的に理解しているわけではありません。例えば、元のデータセットには、端数のある乗客数やマイナスの距離はない(少なくとも今はない)。優れたモデルであれば、一般的にこれを模倣するように学習しますが、これらが整数でなければならないことがわからない場合は、完璧に模倣できないことがあります。

SDVは、このような制約を表現する手段を提供します。これにより、モデリングプロセスが明らかに悪いデータを出力しないようにするための学習に時間を費やす必要がなくなります。制約は次のようなものです:

# Dropoff shouldn't be more than 12 hours after pickup, or before pickup
def is_duration_valid(column_names, data):
 pickup_col, dropoff_col = column_names
 return (data[dropoff_col] - data[pickup_col]) < np.timedelta64(12, 'h')

DurationValid = create_custom_constraint(is_valid_fn=is_duration_valid)
constraints += [DurationValid(column_names=["pickup_datetime", "dropoff_datetime"])]
constraints += [Inequality(low_column_name="pickup_datetime", high_column_name="dropoff_datetime")]

# Monetary amounts should be positive
constraints += [ScalarInequality(column_name=c, relation=">=", value=0) for c in
                ["fare_amount", "extra", "mta_tax", "tip_amount", "tolls_amount"]]
# Passengers should be a positive integer
constraints += [FixedIncrements(column_name="passenger_count", increment_value=1)]
constraints += [Positive(column_name="passenger_count")]
# Distance should be positive and not (say) more than 100 miles
constraints += [ScalarRange(column_name="trip_distance", low_value=0, high_value=100)]
# Lat/lon should be in some credible range around New York City
constraints += [ScalarRange(column_name=c, low_value=-76, high_value=-72) for c in ["pickup_longitude", "dropoff_longitude"]]
constraints += [ScalarRange(column_name=c, low_value=39, high_value=43) for c in ["pickup_latitude", "dropoff_latitude"]]

また、ユーザーが提供するロジックと複数のカラムを含むカスタム制約を記述することも可能です。例えば、ピックアップとドロップオフの緯度/経度、およびタクシー移動距離が与えられます。この2点間の移動距離は、2点間の直線距離よりも長くすることができますが、短くすることはできません!これは、ハバーシン距離の関係で、5つの列の間にある自明ではない要求関係だ。これをカスタム制約として書くのは簡単で、タクシーのGPSによる緯度経度の不正確さを考慮するために、少し余裕を持たせてもよいでしょう:

def is_trip_distance_valid(column_names, data):
  dist_col, from_lat, from_lon, to_lat, to_lon = column_names
  return data[dist_col] >= 0.9 * haversine_dist_miles(data[from_lat], data[from_lon], data[to_lat], data[to_lon])

TripDistanceValid = create_custom_constraint(is_valid_fn=is_trip_distance_valid)
constraints += [TripDistanceValid(column_names=["trip_distance", "pickup_latitude", "pickup_longitude", "dropoff_latitude", "dropoff_longitude"])]

再挑戦する前に、より強力なモデルも視野に入れておくとよいでしょう。

高度な合成データモデリング

上記で使用したSDVの簡単なTabularPresetのアプローチでは、Gaussian copulasを採用しています。聞き慣れない名前かもしれませんが、驚くほどシンプルで速く、多くの問題で効果的です。TabularPresetが問題に対してうまく機能しているのであれば、これ以上探す必要はありません。

複雑な問題では、より複雑なモデルがより良い結果をもたらす可能性があります。SDVは、GANやVAEに基づくアプローチもサポートしています。どちらもディープラーニングを採用したアイデアですが、その方法はさまざまです。GANは、データを生成するモデルと、合成データを検出するために学習するモデルの2つを互いに戦わせ、その出力が本物と見分けがつかなくなるまで生成モデルを改良していきます。VAEは、実データを解読するだけでなく、新しい合成データも空中から「解読」できるように、実データの暗号化を学習します。

どちらも計算量が多く、合理的な時間で処理するにはGPUが必要です。単純なアプローチではエミュレートが難しいデータセットや、カクテルパーティーで「GANを活用しているんですよ」と言えるようなデータセットであれば、SDVの CTGANTVAEがおすすめです。

この後のアップグレードされた例でTVAEを試すのは、もう手間ではありません。また、MLflowを追加することで、メトリクスのログを取り、さらにTVAEモデル自体をpredict関数でデータを増やすだけのモデルとして管理することも可能です:

# Wrapper convenience model that lets the SDV model "predict" new synthetic data
class SynthesizeModel(mlflow.pyfunc.PythonModel):
  def __init__(self, model):
    self.model = model

  def predict(self, context, model_input):
    return self.model.sample(num_rows=len(model_input))

use_gpu = True

with mlflow.start_run():
  metadata = Metadata()
  metadata.add_table(name="nyctaxi_yellow", data=table_nyctaxi)

  model = TVAE(constraints=constraints, batch_size=1000, epochs=500, cuda=use_gpu)
  model.fit(table_nyctaxi)
  sample = model.sample(num_rows=10000, randomize_samples=False)
  
  report = QualityReport()
  report.generate(table_nyctaxi, sample, metadata.get_table_meta("nyctaxi_yellow"))
 
  mlflow.log_metric("Quality Score", report.get_score())
  for (prop, score) in report.get_properties().to_numpy().tolist():
    mlflow.log_metric(prop, score)
    mlflow.log_dict(report.get_details(prop).to_dict(orient='records'), f"{prop}.json")
    prop_viz = report.get_visualization(prop)
    display(prop_viz)
    mlflow.log_figure(prop_viz, f"{prop}.png")

 if use_gpu:
   model._model.set_device('cpu')
 synthesize_model = SynthesizeModel(model)
 dummy_input = pd.DataFrame([True], columns=["dummy"]) # dummy value
 signature = infer_signature(dummy_input, synthesize_model.predict(None, dummy_input))
 mlflow.pyfunc.log_model("model", python_model=synthesize_model,
                         registered_model_name="sdv_synth_model",
                         input_example=dummy_input, signature=signature)

MLflowの使用について!  モデルを MLflow に登録すると、正確なモデルがバージョン管理されたレジストリに記録されます。 反復開発中に作成されたさまざまなモデルの記録を提供するだけでなく、MLflow レジストリを使用すると、他のユーザーにアクセスを許可して、モデルを取得し、合成データを自分で生成することができます。

実際、MLflow からこれらのプロットを確認できます。 品質は 83% までわずかに向上し、新しいプロットが利用可能になり、各列の合成の品質がそれ自体で分類されます。

 

Advanced Synthetic Data Modeling

合成データの生成

宿題が終われば、いくらでも合成データを簡単に生成できます。 ここでは、いくつかの新しく生成されたデータがデルタ テーブルに配置されます。 MLflow からモデルをロードし、データ生成モデルを使用する単純な Python 関数を記述し、それを Spark と並行してダミー入力に "適用" するだけです (UDF には入力が必要ですが、データ生成プロセスには実際には何も必要ありません) 単に結果を書き込みます。

sdv_model = mlflow.pyfunc.load_model("models:/sdv_synth_model/Production").\
  _model_impl.python_model.model

def synthesize_data(how_many_dfs):
  for how_many_df in how_many_dfs:
    yield sdv_model.sample(num_rows=how_many_df.sum().item(), output_file_path='disable')

how_many = len(table_nyctaxi)
partitions = 256
synth_df = spark.createDataFrame([(how_many // partitions,)] * partitions).\
 repartition(partitions).\
 mapInPandas(synthesize_data, schema=df.schema)

display(synth_df)

synth_data_path = ...
synth_df.write.format("delta").save(synth_data_path)

Generating Synthetic Data

Spark は、テラバイト単位で生成する必要がある場合に備えて、生成を並列化するのに非常に役立ちます。 これは、必要なだけ並列化します。

時間、場所などは確かに良くなっています。 pandas-profiling は、実際のデータと合成データを比較する方法を異なる方法で提供できます。 これはレポートのほんの一部です。

synth_data_df = spark.read.format("delta").load(synth_data_path).toPandas()

original_report = ProfileReport(table_nyctaxi, title='Original Data', minimal=True)
synth_report = ProfileReport(synth_data_df, title='Synthetic Data', minimal=True)
compare_report = original_report.compare(synth_report)
compare_report.config.html.navbar_show = False
compare_report.config.html.full_width = True

displayHTML(compare_report.to_html())

Synthetic Data

これにより、品質が 100% ではない理由が詳しくわかります。 たとえば、元のデータはかなり均一であったのに対し、合成データの乗車時間と降車時間には奇妙な不均一性があります。

今のところはこれで十分ですが、合成データ生成プロセスは、他の機械学習プロセスと同じようにここから繰り返され、データと合成プロセスの新しい改善点を発見して品質を向上させる可能性があります。

合成データを使ったモデリング

元のタスクは、データを構成するだけでなく、ヒントを予測することでした。 合成データで機械学習モデルを有用に構築できますか? まともなモデルがこのデータで何をするかを手動で理解するのに時間を費やすのではなく、Databricks Auto ML を使用して最初のパスを作成します。

databricks.automl.regress(
 spark.read.format("delta").load(synth_data_path),
 target_col="tip_amount",
 primary_metric="rmse",
 experiment_dir=tmp_experiment_dir,
 experiment_name="Synth models",
 timeout_minutes=120)

A few hours later:

Modeling with Synthetic Data

どのモデルが最もうまく機能したかの詳細はここでは問題になりません (おめでとうございます、 lightgbm) が、これは適切なモデルがヒントを予測するときに約 1.4 の RMSE を達成でき、R2 が 0.49 であることを示唆しています。

これは、ホールドアウトされた実際のデータのサンプルでモデルを評価した場合に有効ですか? はい、結局のところ、合成データに基づいて構築されたこの最良のモデルは、約 1.52 の RMSE と約 0.49 の R2 も達成しています。 これは優れたモデル パフォーマンスではありませんが、ひどいものではありません。

対照的に、合成データではなく実際のデータから始めていたら、ここで何が起こったでしょうか? Auto ML を再実行し、数時間休憩してから戻ってきて、次を見つけます。

Auto ML

まあ、それはかなり良いです。 さらに、この最良のモデルを実際のデータの同じホールドアウト サンプルでテストすると、同様の結果が得られます。RMSE は 0.94、R2 は 0.78 です。

この場合、実際のデータをモデル化すると、はるかに正確なモデルが生成されます。 それでも、合成データのモデリングによって何かが達成されました。 実際のデータにアクセスせずに、このデータセットでモデルを構築するための実行可能なアプローチであることが証明されました。 それはまずまずのモデルを生成し、他のユースケースでは、合成データのパフォーマンスは同等である可能性さえあります.

これを過小評価しないでください。 これは、機密データにアクセスできない請負業者などによって、モデリング アプローチがハッシュ アウトされる可能性があることを意味します。 パイプラインは、モデルではなく重要な成果物でした。 その後、パイプラインは他のチームによって実際のデータに適用される可能性があります。 チーム間でパイプラインの開発とデプロイを分割する方法の詳細については、Big Book of MLops を参照してください。

最後に、合成データはデータ拡張の戦略にもなります。 実際のデータにアクセスできるチームの場合、合成データを追加すると、モデルがわずかに改善される可能性があります。 好奇心のために、結果を繰り返さないでください。実際のデータと合成データを組み合わせて使用する Auto ML でのこの同じアプローチでは、0.95 の RMSE と 0.77 の R2 が得られます。 この場合は実質的に違いはありませんが、他の場合は可能性があります。

まとめ

ジェネレーティブ AI の力は、面白いチャットだけにとどまりません。 現実的な合成ビジネス データを作成できます。これは、機微な実際のデータへのアクセスを簡単に保護できない機械学習チームにとって有用な代役となる可能性があります。 SDV のようなツールを使用すると、このプロセスをわずか数行のコードで実行でき、結果のモデルとデータを管理するために Spark、Delta、および MLflow とうまく組み合わせることができます。

Try it now on Databricks!

Databricks 無料トライアル

関連記事

Dolly:オープンなモデルで ChatGPT の魔法を民主化

概要 Databricks では、従来のオープンソースの大規模言語モデル(LLM)を利用して ChatGPT のような命令追従能力を実現できることを確認しました。高品質な学習データを使用して 1 台のマシンで 30 分ほどトレーニングするだけです。また、命令追従能力の実現には、必ずしも最新のモデルや大規模なモデルは必要ないようです。GPT-3 のパラメータ数が 1750 億であるのに対し、私たちのモデルでは 60 億です。私たちはモデル Dolly のコードをオープンソース化しています。Dolly を Databricks 上でどのように再作成できるか、今回のブログではこのことについて詳しく解説します。 Dolly のようなモデルは LLM の民主化を促進します。LLM...
エンジニアリングのブログ一覧へ