因果探索ライブラリcausalnex
quantumblackって名前がかっこいい
以前quantumblack社がオープンソースで開発しているkedro
を紹介しました。
これ、kedro
開発者にも見つかっていろいろと反応をいただいて大変光栄でした。
今回は同社が同様にオープンソースで開発しているcausalnex
を使ってみます。
理由は単純。
ロゴが、かっこいい。
まだチュートリアル段階ですが、これだけでも十分に実用に足るんじゃなかろうか、
というくらいには強い。仕事のOSがUbuntuなら完全に優勝していた。
githubはこちら
github.com
基本的には公式チュートリアルに沿ってやっていきます。
進捗的には因果グラフをつくるところまで。
チュートリアルではベイジアンネットワークの作成ができていますが、
そこまで追いついていないので一旦保留。
下準備
causalnex
はgraphviz
を使って可視化を進めているようです。
graphviz
はネットワーク図を描画するツールパッケージで、
コイツは別途インストールが必要です。
きぬいとはUbuntu 20.04なのでこんな感じ。
なんか開発版がないと可視化でコケたので、開発版も入れておきます*1。
sudo apt install graphviz sudo apt install graphviz-dev
きぬいとはpipenv
で仮想環境を作って、VSCodeでコードを書いているので、
作業ディレクトリ上で以下を実行してpipenv
上にインストール。
この時pandas
とかmatplotlib
とかpygraphviz
などを合わせて入れておきます。
また、可視化にはIPython.display
からImage
モジュールを使っているので、
そのアタリを入れておきましょう。
pipenv install causalnex # ...あとはお好きなライブラリをインストール。 # Pipfileがlockされたらshell/run pipenv shell
実際ここで詰まったので、ここだけ書いたらあとはgithubからソース読んで欲しい(えー)。
causalnex
で因果ダイアグラムを描いていく。
今回はチュートリアルに即してUCIが提供する学力データを採用します。
入っている変数はリンクをたどれば見られるので、詳細はそちらを。
データとしては数値とカテゴリー変数が混ざったデータで、SEMるには良さそうなデータかも知れません。
causalnex
のチュートリアルでは、性別や年齢、学校などは"sensitive"として、
モデルから削除しています*2。
そこまでやったのがこちら。
# %% load data student_por = pd.read_csv('../sample_data/student-por.csv', sep=';') student_por.shape >> (649, 33) # %% # to drop sensitive features(to avoid statistical discrimination) drop_col = ['school', 'sex', 'age', 'Mjob', 'Fjob', 'reason', 'guardian'] student_por.drop(drop_col, axis=1, inplace=True)
今回はカテゴリー変数をLabelEncodingします。
カテゴリー変数の操作はもちろん、どのようにグラフ化したいかにも寄るでしょうから、
実務での問題設計に応じて対応すればよろしいと思います。
この辺はkaggleのベースラインモデルでもよく見るコーディングですなあ。
# %% preprocessing # label encoding le = LabelEncoder() categorical_features = list( student_por.select_dtypes(exclude=[np.number]).columns) categorical_features # %% for cat_col in categorical_features: student_por[cat_col] = le.fit_transform(student_por[cat_col])
さて、続いては因果ダイアグラムの作成です。
causalnex
はNo-Tearsアルゴリズムを使ってグラフを作成できます。
アルゴリズムについては追って読んでみます。本当に読みます。はい……
ただ、数分時間がかかります。行なのか列なのかはわかりませんが、
この規模で時間がかかるので、まあ、そういうことなんでしょう。
# %% NOTEARS algorithm structure # maybe some times to calculate. no_tears_sm = from_pandas(student_por) # %% visualize viz = plot_structure(no_tears_sm, graph_attributes={"scale": "0.5"}, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK) Image(viz.draw(format='png'))
きぬいとの環境ではOSError
が起きて終わってしまいましたが*3、
公式チュートリアルではこんな画像が。
グラフ、画像がかっこいい。
ではなく、デフォルトでは完全グラフになってしまうようです。そりゃローカルPCじゃ落ちる。
チュートリアル通り、閾値を設けて足切りをしていきましょう。
足切りはremove_edges_below_threshold
で閾値を設定することで可能です。
# %% no_tears_sm.remove_edges_below_threshold(0.8) viz = plot_structure(no_tears_sm, graph_attributes={"scale": "0.5"}, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK) Image(viz.draw(format='png'))
すると
ほう……かっこいい。
更にテコ入れします。グラフを見るとhigher
(高等教育を望んでいるか)がMedu
(母親の学歴)に向かって矢印が伸びています。
自身の学歴を高くしたいかどうかが、母親の学歴を規定する、という構造はちょっと説明が難しいですね。
むしろ、こういうのは逆方向か、あるいは関係ないか。
今回は「関係ない」という制約を置くことにします*4。
また、いくつかの変数に矢印を引くこともここで一気にやってみます。
# %% cut the relationship between higher and Medu no_tears_sm = from_pandas(student_por, tabu_edges=[('higher', 'Medu')], w_threshold=0.8) # %% add or remove edges no_tears_sm.add_edge('failure', 'G1') no_tears_sm.remove_edge('Pstatus', 'G1') no_tears_sm.remove_edge('address', 'G1') viz = plot_structure(no_tears_sm, graph_attributes={"scale": "0.5"}, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK) Image(viz.draw(format='png'))
その結果こう。しっかりhigher
とMedu
の関係性が切れ、
いくつかの関係性を追加・削除できたようです。
そしてよく見ると、どこにも矢印が伸びていない点や、
Dalc
とWalc
にだけグラフができているなどがあります。
続く(であろう)ベイジアンネットワークには一旦これらの変数は使わないことにして、
もっとも大きなサブグラフを抜き出します。
# %% get the largest subgraph. no_tears_sm = no_tears_sm.get_largest_subgraph() viz = plot_structure(no_tears_sm, graph_attributes={"scale": "0.5"}, all_node_attributes=NODE_STYLE.WEAK, all_edge_attributes=EDGE_STYLE.WEAK) Image(viz.draw(format='png'))
結果はこう。
見やすくなった。
この因果ダイアグラムを使って、そのうちベイジアンネットワークを構築していきます。
……といったところで今回はここまで。
モデルは走らせると結構重いので、pickle
で保存しておきます。
# %% save sm filename = '../output/no_tears_sm.pkl' pickle.dump(no_tears_sm, open(filename, 'wb'))
完走した感想
いやまだ完走はしてませんが……
各APIが充実していて、因果ダイアグラムを
- 機械的に出す
- それを見て違和感があるところについて修正する
を繰り返してグラフ構造が得られるのは、なんだか面白いと思います。
次回はベイジアンネットワークをやっていく予定ですが、
ここからDAGベースでの線形回帰モデル、非線形回帰モデルなども実装可能なようです。
また、今回端折った「因果性」についても思った以上に詳しく記述があります。
あくまで「因果探索」のアルゴリズムですが、ここまでAPIが充実していれば、
実務にも応用はできそうな予感がします。
注意
causalnex
は強力な武器ですし、因果探索は実務上じっくり進めていく必要があります。
実際、矢印を引くには計算時間と計算リソースを大きく割きますし。
また、causalnex
ではLiNGAMは実装できない?模様です。
introductionにもあるとおり、
あくまで初心者(というよりは因果モデルをグラフ表現できても実装ができない)に
簡単にモデリングできるよう設計されているので、シンプルな実装になっているようです。
LiNGAMが射程範囲なのかどうかすら微妙ですが。
ともあれ、とりあえずベイジアンネットワーク作ってみっかw*5