と。

Github: https://github.com/8-u8

Sparse Group LassoをやるためのRライブラリ{SGL}を使ってみた

この記事は

別にR Advent Calendarの記事ではありません(え?)
22日目に書くので、許してください。

仕事で使うかどうかも微妙なのですが、いくつかのグループごとにネストされた説明変数を使った回帰問題にぶち当たりました。
普通のLassoをするとグループ関係なしに変数が選択されてかなしい。
かたやGroup Lassoをするとグループ単位で変数が選択されて、これもかないい。
それらのいいとこ取りをしようというモデルがSparse Group Lassoのようです。

元論文はこちら。

arxiv.org

Rによる実装があります。

github.com

Sparse Group Lassoって?

定式化は上記の論文を見るとわかります。*1

αがハイパーパラメータとして存在し、その大小でGroup単位での罰則を強めるか、通常Lassoの罰則に寄せるかを決定するようです。 そして罰則の強さはLambdaで決まる……
Elastic NetのRidge系罰則項をGroup Lassoの項に置き換えた感じですね。うん、シンプル!

上にも書いたようにRによる実装{SGL}があります。

雑に実装したので下記に置いときます。多分cloneすれば勝手に動くと思います。

github.com

挙動が謎だったのでとりあえずいろいろ調べてみているやつです。
とりあえずalphaを0、1、0.5でやってみました。 あと理由は知りませんがハイパーパラメータのlambdaも好きに出せるっぽいので*2
とりあえず乱数をソートして2000個で、挙動としてどう変化するのか見てやろうと思います。

# load libraries
library(SGL)
library(tidyverse)

# dummy data creation
set.seed(1)
n = 500; p = 50; size.groups = 12
index <- ceiling(1:p / size.groups)
X = matrix(rnorm(n * p, mean = 0, sd = 2), ncol = p, nrow = n)
for(k in 1:ncol(X)){
  rand_param <- runif(1, min = -15, max = 15)
  beta = rnorm(n = ncol(X), mean = 0, sd = 1) * rlnorm(ncol(X), meanlog = rand_param)
}
plot(beta)
y = X %*% beta + 0.1*rnorm(n)
y <- scale(y)
X <- scale(X)
plot(y)
sampledata <- data.frame(X = X, y = y)
sampledata %>% head

        X.1         X.2         X.3        X.4         X.5        X.6        X.7        X.8        X.9       X.10       X.11
1 -0.6414465  0.11656388  1.12708617  0.8217993 -0.86300747 -1.7810168  0.6102893  0.5720952 -1.1480633  1.5970074 -1.5815030
2  0.1591014 -0.23732833  1.10427289 -0.8369340 -1.86180245 -0.6819537  0.2789401  0.4302766  0.7843899 -0.6977968  0.6754861
3 -0.8481556 -1.07566180 -0.85950186  0.8624771  1.55260675 -0.4824233  1.1341242 -0.5624105  0.5871475 -1.4750089 -1.7517314
4  1.5540989  0.05413111  0.21167903 -0.8515996  0.49180262  0.9805747 -0.8397948 -0.5117973 -1.3689292 -0.5347501  1.2547356
5  0.3032465  0.98130821  0.07169284  0.5311435 -0.06260705 -0.6033307 -1.5909031  0.3657704 -2.0589998 -2.0774948  1.1893811
6 -0.8331741  1.55102726 -1.64381076 -0.1424253  0.66257120  1.1096734  0.7927736 -0.6454655  0.6072623  0.5663933 -1.2883884
        X.12        X.13       X.14       X.15       X.16       X.17        X.18        X.19        X.20       X.21       X.22
1 -1.1191492 -0.57737271  0.9276971 -1.3259908 -0.7864862  0.3145730  1.01503041 -1.27468766  1.33747964 -0.8594733  0.7489229
2 -0.9344597 -1.03395529 -0.3338209  0.9643525  0.6023669 -0.7403743 -0.08994914 -0.99628820 -1.45781750 -1.1286214 -0.6630628
3  1.6176808 -2.02131685  0.1566056  0.8718543 -1.0518508 -1.3504612  0.24933451  0.07011252  0.25529006 -1.1060711 -0.3409734
4  0.1769211 -0.03058095 -1.6769201  1.0737816  1.0119338  1.6852180 -0.67041524  0.69696057 -0.20977460 -1.2663307  0.3827110
5 -0.8911483 -0.24379348  0.7845712 -0.3456156  0.1775735 -1.4302605 -0.83252123  0.66825752  0.08848931 -0.5351517 -0.9904053
6  1.3422600  0.49593103 -0.8405994 -0.1245477 -1.0206286 -0.1241287  0.30933389  1.27680142  1.59185901 -0.5613513 -1.0614717
        X.23       X.24        X.25       X.26       X.27       X.28       X.29       X.30       X.31        X.32       X.33
1 -1.3360927 -0.3558571 -0.91849513 -0.8386021  0.2309046  0.7105695  0.5809348 -0.5180316 -0.2220064 -0.48507511  0.8271434
2  1.0770012 -0.9650745  1.43381719 -0.6059831 -0.8242734  0.5115382  1.0571611  0.5902407 -0.1134155  0.61942507 -1.6343922
3  1.1605060 -0.0234316  1.66791176  0.3638292 -0.8531086 -2.2163279  0.3282538 -2.3061569 -0.4759732 -1.57227531 -0.2553172
4  0.3139110  0.3227936  0.44090516  0.9569211 -2.0013116 -0.1488372 -0.6526357 -1.3716650 -0.6985364  1.37573429 -1.0467824
5 -0.5072653  0.0950360 -0.06485584  0.7705720 -0.8184016 -0.1333218 -0.6760587  0.9379332 -0.8064612 -0.08186681 -1.1543301
6 -0.3612579  0.3924331  0.23630317 -1.3458728  1.4780950  1.5085669 -0.0522178 -0.8454674 -0.3487007 -1.19166789 -1.5318295
        X.34       X.35       X.36       X.37         X.38       X.39       X.40       X.41       X.42       X.43        X.44
1 -0.8303923  0.9809046 -2.4903815  0.9782235 -1.380000417 -2.0443493 -0.9377986  0.1270102  0.5572240  0.5681752 -0.57069814
2 -2.2770677  0.2517978 -0.7295903  0.4797610  2.067938630 -0.4026967  1.4445560  0.1367522  0.2931986 -2.8776170  0.13664573
3 -2.3720160  1.1657603  1.6079233 -0.3428822 -1.647147807 -0.1105207  0.2644487 -0.7751125 -0.7782824  0.1389422 -0.24520146
4  0.9134221 -0.6848332 -0.6427748 -1.0358210  1.161129489  0.9647438  1.1808350 -2.1039515 -0.9130585  2.2672263  0.02408309
5  0.7720436 -0.6682041 -0.5376536  0.3289408  0.008394193  0.0270925 -0.6427501  0.9528659 -0.6446500  0.6816090 -0.39885402
6  0.3992770  0.8765229  0.9647004  2.0299603  2.188120377  1.9363709 -0.8951111 -0.4064273 -0.1268170  1.1371677 -0.91984778
          X.45       X.46       X.47        X.48       X.49       X.50          y
1 -0.191373790  1.5314460  1.5191370 -0.84689213  0.6943624 -0.9548680 -0.2907983
2 -0.008471448 -0.5193922  0.1364774 -0.68132309 -0.6682177  0.1558025 -0.7982820
3  0.247918285 -2.5151220 -0.2335502  1.69160660 -1.4858869  0.7547085  1.2480950
4 -0.118767974  0.9607987 -0.2270655 -0.06780635 -1.8934025 -0.5497518 -2.3044130
5  0.005952484  0.1974666 -1.2365630 -1.17937816  0.5291765 -0.2172036 -0.3080377
6  0.062142767  0.9215631  1.1493869 -1.53989948  0.2133129  0.5121727  1.9143624
# modeling SGL
fit <- list(x=X, y=y) # List型での読み込みっぽい

## デフォルトでSGL側でlambda作ってくれるけどこの辺統一して
## 結果が見たかったので適当に作る
lambda <- sort(abs(rnorm(2000,mean = 0, sd = 0.0005)),decreasing = TRUE)
plot(lambda)

# alpha = 0 -> Group Lasso
fitSGL_zero_alpha <- SGL::SGL(data = fit,
                              index = index,
                              type = 'linear',
                              min.frac = 0.01,
                              alpha = 0,
                              verbose = TRUE,
                              thresh  = 0.1,
                              gamma = 0.4,
                              #nlam = 2000,
                              lambdas = lambda)

# alpha = 1 -> Normal Lasso
fitSGL_one_alpha <- SGL::SGL(data = fit,
                             index = index,
                             type = 'linear',
                             min.frac = 0.01,
                             alpha = 1,
                             verbose = TRUE,
                             thresh  = 0.1,
                             gamma = 0.4,
                             #nlam = 2000,
                             lambdas = lambda)

# alpha = 0.5 -> Sparse Group Lasso
fitSGL_half_alpha <- SGL::SGL(data = fit,
                              index = index,
                              type = 'linear',
                              min.frac = 0.01,
                              alpha = 0.5,
                              verbose = TRUE,
                              thresh  = 0.1,
                              gamma = 0.4,
                              #nlam = 2000,
                              lambdas = lambda)

result <- data.frame(X_name          = colnames(sampledata)[-51],
                     group           = index,
                     beta_nlam_200_zero   = fitSGL_zero_alpha$beta[,200],
                     beta_nlam_500_zero   = fitSGL_zero_alpha$beta[,500],
                     beta_nlam_1000_zero  = fitSGL_zero_alpha$beta[,1000],
                     beta_nlam_2000_zero  = fitSGL_zero_alpha$beta[,2000],
                     beta_nlam_200_one    = fitSGL_one_alpha$beta[,200],
                     beta_nlam_500_one    = fitSGL_one_alpha$beta[,500],
                     beta_nlam_1000_one   = fitSGL_one_alpha$beta[,1000],
                     beta_nlam_2000_one   = fitSGL_one_alpha$beta[,2000],
                     beta_nlam_200_half    = fitSGL_half_alpha$beta[,200],
                     beta_nlam_500_half    = fitSGL_half_alpha$beta[,500],
                     beta_nlam_1000_half   = fitSGL_half_alpha$beta[,1000],
                     beta_nlam_2000_half   = fitSGL_half_alpha$beta[,2000])

結果は眠くなってきたので起きたら追記します。
alphaの強さでLasso←→Group Lassoという感じになります。 alphaを中間におけば、lambdaの強さでGroupを選びつつ、選ばれたGroup内でも更に重要な特徴量にパラメータが入り、それ以外は0となる感じです。

ただ、この{SGL}パッケージはcv.SGLとかもあるのですが、CVした結果をpredictしようとすると失敗するちょっと謎挙動な面もあります。

多分私がこの辺のリスト運用を十分に理解していないだけですが、私個人の目的は「説明モデル」なのでまあCVはいいとします。

起きたので結果貼っときます。

result %>% dplyr::select(X_name,group,contains('zero'))
   X_name group beta_nlam_200_zero beta_nlam_500_zero beta_nlam_1000_zero beta_nlam_2000_zero
1     X.1     1       -0.919070437        -0.91223722        -0.890147077         -0.82694598
2     X.2     1       -1.930531437        -2.19526196        -2.476415069         -2.93523804
3     X.3     1      -12.317403585       -13.23502437       -14.136662526        -15.47656029
4     X.4     1       -0.069421890         0.01577950         0.114029993          0.28688744
5     X.5     1       -1.432918096        -1.51957554        -1.602283187         -1.72154480
6     X.6     1        0.098980538         0.05132270         0.003071861         -0.06760882
7     X.7     1        1.659840314         1.73261807         1.799260229          1.88947418
8     X.8     1        0.903728861         0.97595933         1.049261612          1.16242585
9     X.9     1        0.405605859         0.55254002         0.718471323          1.00678492
10   X.10     1        0.217522682         0.20660601         0.187691813          0.14238394
11   X.11     1       -0.042543063         0.02625380         0.102333878          0.22922301
12   X.12     1       -0.084837994        -0.13775980        -0.197679733         -0.30108252
13   X.13     2       -2.957216078        -3.42081303        -3.892311519         -4.62764424
14   X.14     2        0.325786662         0.41627693         0.519415032          0.69859071
15   X.15     2        0.064942034        -0.01403526        -0.105939783         -0.26506662
16   X.16     2       -2.464187631        -2.81228176        -3.164572871         -3.71407507
17   X.17     2       -4.689638227        -5.40708826        -6.122918675         -7.20150107
18   X.18     2       -0.457018318        -0.52695380        -0.596593342         -0.70305740
19   X.19     2        0.744436815         0.87213738         1.003763762          1.21089624
20   X.20     2       -2.185698610        -2.49315625        -2.795311993         -3.24520822
21   X.21     2        0.083014664         0.08738530         0.088033853          0.08155590
22   X.22     2        0.056260046         0.10999228         0.175066289          0.29424011
23   X.23     2       -0.543896257        -0.70182277        -0.875619359         -1.16895024
24   X.24     2       -1.796261729        -2.06432250        -2.327537291         -2.71462352
25   X.25     3       -1.109027475        -1.13076916        -1.132408024         -1.09541273
26   X.26     3       -2.099496839        -2.26297384        -2.404382448         -2.57207836
27   X.27     3        4.961659623         5.74135666         6.561114600          7.88507367
28   X.28     3       -0.564386152        -0.57442040        -0.575266404         -0.56071205
29   X.29     3       -0.354246668        -0.32785197        -0.292340583         -0.22226264
30   X.30     3       -0.265786419        -0.28180872        -0.291834989         -0.29236051
31   X.31     3       -0.236560681        -0.35603860        -0.493107403         -0.73576728
32   X.32     3       -3.362084463        -3.72466569        -4.078353708         -4.59866290
33   X.33     3        0.816394337         0.86950099         0.910237164          0.94742844
34   X.34     3        0.993080139         1.18237003         1.381068024          1.69771805
35   X.35     3        0.362955455         0.38273267         0.399396209          0.42023231
36   X.36     3        2.331085686         2.57776024         2.818342879          3.17604085
37   X.37     4       -0.214400181        -0.34282474        -0.485030976         -0.71793845
38   X.38     4       -0.290136397        -0.38180041        -0.443617296         -0.47778922
39   X.39     4       -0.567452174        -0.78532932        -0.974745864         -1.20924492
40   X.40     4        0.163307858         0.24969092         0.345564578          0.51448056
41   X.41     4       -0.201417672        -0.29715830        -0.395285559         -0.54989507
42   X.42     4        0.444342602         0.61175382         0.752005251          0.91021997
43   X.43     4       -0.143933432        -0.20472800        -0.258366254         -0.31826404
44   X.44     4        0.187931419         0.35988375         0.581557757          1.00509886
45   X.45     4        0.450043458         0.71418455         1.007201340          1.49530596
46   X.46     4        1.404411102         2.09577419         2.805750553          3.91116715
47   X.47     4        0.523466943         0.73526838         0.925618307          1.16879890
48   X.48     4       -0.328033761        -0.47946227        -0.628296270         -0.84784914
49   X.49     5        0.009395451         0.02301674         0.042425085          0.08419528
50   X.50     5        0.263082542         0.49567410         0.734026180          1.10225507
 result %>% dplyr::select(X_name,group,contains('one'))
   X_name group beta_nlam_200_one beta_nlam_500_one beta_nlam_1000_one beta_nlam_2000_one
1     X.1     1      -0.210155828       -0.39016041       -0.573232052        -0.82648639
2     X.2     1      -1.650955055       -2.01804364       -2.379829875        -2.93518858
3     X.3     1     -14.088599351      -14.49044986      -14.896715316       -15.47770626
4     X.4     1       0.000000000        0.00000000        0.000000000         0.28675074
5     X.5     1      -1.059950009       -1.24901107       -1.425890524        -1.72127099
6     X.6     1       0.000000000        0.00000000        0.000000000        -0.06778047
7     X.7     1       1.173343536        1.38532825        1.590406994         1.88915820
8     X.8     1       0.224326726        0.47838732        0.749474619         1.16201732
9     X.9     1       0.004239839        0.22331595        0.532925643         1.00665039
10   X.10     1       0.000000000        0.00000000        0.000000000         0.14191566
11   X.11     1       0.000000000        0.00000000        0.000000000         0.22895884
12   X.12     1       0.000000000        0.00000000       -0.001720018        -0.30084480
13   X.13     2      -3.496414479       -3.83873317       -4.139557920        -4.62796518
14   X.14     2       0.000000000        0.08242796        0.323029731         0.69832269
15   X.15     2       0.000000000        0.00000000        0.000000000        -0.26486919
16   X.16     2      -2.784422883       -3.04788857       -3.309957209        -3.71430873
17   X.17     2      -5.945841847       -6.30991982       -6.654541776        -7.20228229
18   X.18     2       0.000000000       -0.14016626       -0.347146673        -0.70270792
19   X.19     2       0.251420653        0.54367473        0.819606074         1.21069296
20   X.20     2      -2.320966895       -2.59190495       -2.854564218        -3.24534018
21   X.21     2       0.000000000        0.00000000        0.000000000         0.08105258
22   X.22     2       0.000000000        0.00000000        0.000000000         0.29378796
23   X.23     2      -0.144296150       -0.43963497       -0.707836624        -1.16877476
24   X.24     2      -1.713783585       -2.01100494       -2.289897812        -2.71452870
25   X.25     3      -0.690514976       -0.82392324       -0.953027264        -1.09515934
26   X.26     3      -1.976767955       -2.15285503       -2.320915379        -2.57190360
27   X.27     3       6.290702907        6.71390545        7.135755788         7.88603441
28   X.28     3       0.000000000       -0.07774605       -0.286099036        -0.56031342
29   X.29     3       0.000000000        0.00000000        0.000000000        -0.22174427
30   X.30     3       0.000000000        0.00000000       -0.044224606        -0.29196346
31   X.31     3       0.000000000       -0.06872451       -0.306951242        -0.73553077
32   X.32     3      -3.836494454       -4.06457583       -4.258222298        -4.59891608
33   X.33     3       0.268821646        0.45127310        0.653261829         0.94702598
34   X.34     3       0.461179441        0.78903853        1.137620739         1.69740740
35   X.35     3       0.000000000        0.00000000        0.147947029         0.41990880
36   X.36     3       2.406916552        2.63306670        2.822770708         3.17600633
37   X.37     4       0.000000000       -0.14959592       -0.372905407        -0.71782679
38   X.38     4      -0.206875866       -0.31432424       -0.384108249        -0.47768860
39   X.39     4      -0.786609108       -0.93279248       -1.044085430        -1.20932995
40   X.40     4       0.000000000        0.00000000        0.143428052         0.51421076
41   X.41     4       0.000000000        0.00000000       -0.173713257        -0.54956905
42   X.42     4       0.044241695        0.25796974        0.503832690         0.90982399
43   X.43     4       0.000000000        0.00000000        0.000000000        -0.31783463
44   X.44     4       0.000000000        0.12970019        0.452602846         1.00498571
45   X.45     4       0.361157846        0.64960377        0.968465960         1.49530939
46   X.46     4       2.699629400        3.03531711        3.362575526         3.91205990
47   X.47     4       0.240327949        0.44958473        0.716930588         1.16850450
48   X.48     4      -0.050930865       -0.25899369       -0.473274682        -0.84762844
49   X.49     5       0.000000000        0.00000000        0.000000000         0.08377711
50   X.50     5       0.477798540        0.63056910        0.797658128         1.10238347
 result %>% dplyr::select(X_name,group,contains('half'))
   X_name group beta_nlam_200_half beta_nlam_500_half beta_nlam_1000_half beta_nlam_2000_half
1     X.1     1        -0.62071858        -0.68592510       -7.458937e-01         -0.82671621
2     X.2     1        -1.81613604        -2.10467881       -2.419279e+00         -2.93521332
3     X.3     1       -13.15534161       -13.81572666       -1.448429e+01        -15.47713327
4     X.4     1         0.00000000         0.00000000        4.505189e-02          0.28681910
5     X.5     1        -1.25428135        -1.38202880       -1.510880e+00         -1.72140789
6     X.6     1         0.00000000         0.00000000        0.000000e+00         -0.06769530
7     X.7     1         1.46681995         1.58414322        1.698336e+00          1.88931623
8     X.8     1         0.60066605         0.74427058        9.031424e-01          1.16222161
9     X.9     1         0.20904582         0.40815474        6.236052e-01          1.00671777
10   X.10     1         0.00000000         0.00000000        3.291013e-02          0.14214974
11   X.11     1         0.00000000         0.00000000        5.192025e-04          0.22909111
12   X.12     1         0.00000000         0.00000000       -1.037752e-01         -0.30096377
13   X.13     2        -3.21288959        -3.60442529       -4.002368e+00         -4.62780470
14   X.14     2         0.12202764         0.26392054        4.222991e-01          0.69845672
15   X.15     2         0.00000000         0.00000000       -3.092604e-02         -0.26496832
16   X.16     2        -2.61210417        -2.92004099       -3.232421e+00         -3.71419195
17   X.17     2        -5.25251744        -5.80951880       -6.370108e+00         -7.20189168
18   X.18     2        -0.22761540        -0.34978747       -4.809991e-01         -0.70288269
19   X.19     2         0.54802700         0.72532166        9.216449e-01          1.21079462
20   X.20     2        -2.27161813        -2.54633236       -2.828119e+00         -3.24527421
21   X.21     2         0.00000000         0.00000000        0.000000e+00          0.08130426
22   X.22     2         0.00000000         0.00000000        1.696232e-02          0.29401407
23   X.23     2        -0.36478506        -0.56698409       -7.951213e-01         -1.16886252
24   X.24     2        -1.79163500        -2.04976522       -2.309077e+00         -2.71457609
25   X.25     3        -0.97476956        -1.02116629       -1.051839e+00         -1.09528610
26   X.26     3        -2.06407823        -2.21940840       -2.370499e+00         -2.57199099
27   X.27     3         5.53571039         6.15829864        6.826187e+00          7.88555406
28   X.28     3        -0.29840729        -0.36788606       -4.453582e-01         -0.56051274
29   X.29     3        -0.06009490        -0.08937288       -1.266488e-01         -0.22200355
30   X.30     3        -0.04130928        -0.10497539       -1.701071e-01         -0.29216206
31   X.31     3        -0.05019059        -0.20577701       -3.951464e-01         -0.73564898
32   X.32     3        -3.59157651        -3.87620007       -4.160506e+00         -4.59878953
33   X.33     3         0.60190406         0.69637694        7.923339e-01          0.94722724
34   X.34     3         0.78515946         1.01238011        1.265741e+00          1.69756277
35   X.35     3         0.09410046         0.18374424        2.795800e-01          0.42007052
36   X.36     3         2.38218720         2.59638278        2.810935e+00          3.17602361
37   X.37     4        -0.14194709        -0.27645814       -4.400724e-01         -0.71788262
38   X.38     4        -0.28305880        -0.35714344       -4.217548e-01         -0.47773895
39   X.39     4        -0.68132091        -0.84978003       -1.004760e+00         -1.20928744
40   X.40     4         0.03022824         0.12795054        2.628507e-01          0.51434564
41   X.41     4        -0.09162862        -0.17734826       -2.917955e-01         -0.54973218
42   X.42     4         0.35231196         0.49767247        6.510938e-01          0.91002208
43   X.43     4         0.00000000        -0.04029963       -1.341940e-01         -0.31804938
44   X.44     4         0.07545063         0.26477989        5.228015e-01          1.00504226
45   X.45     4         0.44775922         0.69952400        9.988958e-01          1.49530767
46   X.46     4         1.90573821         2.47731859        3.052888e+00          3.91161352
47   X.47     4         0.46320515         0.64506385        8.493478e-01          1.16865175
48   X.48     4        -0.27004878        -0.40230047       -5.592959e-01         -0.84773886
49   X.49     5         0.00000000         0.00000000        0.000000e+00          0.08398628
50   X.50     5         0.34958263         0.55116199        7.628552e-01          1.10231930

alphaを0.5(気持ち的にちょうど中間)に設定しても普通のLassoに寄った結果になっているように見えます。
データ構造次第ですが、Group Lasso感を強めるためにはalphaはもうちょい弱い感じなのかもです。
……というかはてなブログ数式使えたよね?なんで数式反映されへんの?

*1:よく見る形になっているのは上記論文よりこの論文

*2:ドキュメントでは「とくに思うところがなければ俺たちに任せてほしい」って書いてるので、運用上は任せてやってもいいと思います。