Trying out {SGL}, an R package for Sparse Group Lasso
About this post: it is not actually an R Advent Calendar entry (what?).
I'm writing it on day 22, so please forgive me.
I'm not even sure I'll use this at work, but I ran into a regression problem whose explanatory variables are nested within several groups.
Plain Lasso selects variables with no regard for the groups, which is sad.
Group Lasso, on the other hand, selects variables only at the group level, which is also sad.
Sparse Group Lasso appears to be the model that takes the best of both.
The original paper is here.
There is an R implementation.
What is Sparse Group Lasso?
The formulation is given in the paper above.*1
There is a hyperparameter α, and its value decides whether to put more weight on the group-level penalty or lean toward the ordinary Lasso penalty.
The overall strength of the penalty is then set by lambda……
In short, it's Elastic Net with the ridge-type penalty term swapped for the Group Lasso term. Nice and simple!
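For reference, here is the objective as I read it in the paper (my own transcription, so check Simon et al. for the exact form): for groups l = 1, …, m with sizes p_l,

```latex
\min_{\beta}\;
\frac{1}{2n}\Big\| y - \sum_{l=1}^{m} X^{(l)} \beta^{(l)} \Big\|_2^2
+ (1-\alpha)\,\lambda \sum_{l=1}^{m} \sqrt{p_l}\,\big\|\beta^{(l)}\big\|_2
+ \alpha\,\lambda\,\|\beta\|_1
```

Setting α = 1 recovers the plain Lasso, and α = 0 recovers the Group Lasso, which matches the experiment below.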
As mentioned above, there is an R implementation, {SGL}.
I hacked together a quick script and put it below; cloning it should be enough to make it run.
The behavior puzzled me, so this is me poking at it from a few angles.
For a start I ran alpha at 0, 1, and 0.5.
Also, I don't know why you'd normally do this, but it seems you can supply the hyperparameter lambda sequence yourself,*2
so I sort 2000 random draws and use those, to see how the behavior changes.
```r
# load libraries
library(SGL)
library(tidyverse)

# dummy data creation
set.seed(1)
n = 500; p = 50; size.groups = 12
index <- ceiling(1:p / size.groups)  # 5 groups: four of size 12, one of size 2
X = matrix(rnorm(n * p, mean = 0, sd = 2), ncol = p, nrow = n)
for(k in 1:ncol(X)){
  rand_param <- runif(1, min = -15, max = 15)
  # note: beta is overwritten on every pass, so only the last draw survives
  beta = rnorm(n = ncol(X), mean = 0, sd = 1) * rlnorm(ncol(X), meanlog = rand_param)
}
plot(beta)
y = X %*% beta + 0.1 * rnorm(n)
y <- scale(y)
X <- scale(X)
plot(y)
sampledata <- data.frame(X = X, y = y)
sampledata %>% head
## (prints the first six rows of the 51-column sampledata data frame:
##  X.1 ... X.50 plus y)
```
```r
# modeling SGL
fit <- list(x = X, y = y)  # SGL seems to take the data as a list
## SGL builds its own lambda sequence by default, but I wanted the
## same sequence across fits, so I make one (somewhat arbitrarily) here
lambda <- sort(abs(rnorm(2000, mean = 0, sd = 0.0005)), decreasing = TRUE)
plot(lambda)

# alpha = 0 -> Group Lasso
fitSGL_zero_alpha <- SGL::SGL(data = fit, index = index, type = 'linear',
                              min.frac = 0.01, alpha = 0, verbose = TRUE,
                              thresh = 0.1, gamma = 0.4,
                              #nlam = 2000,
                              lambdas = lambda)

# alpha = 1 -> plain Lasso
fitSGL_one_alpha <- SGL::SGL(data = fit, index = index, type = 'linear',
                             min.frac = 0.01, alpha = 1, verbose = TRUE,
                             thresh = 0.1, gamma = 0.4,
                             #nlam = 2000,
                             lambdas = lambda)

# alpha = 0.5 -> Sparse Group Lasso
fitSGL_half_alpha <- SGL::SGL(data = fit, index = index, type = 'linear',
                              min.frac = 0.01, alpha = 0.5, verbose = TRUE,
                              thresh = 0.1, gamma = 0.4,
                              #nlam = 2000,
                              lambdas = lambda)

# collect coefficients at several points along the lambda path
result <- data.frame(X_name = colnames(sampledata)[-51],
                     group  = index,
                     beta_nlam_200_zero  = fitSGL_zero_alpha$beta[, 200],
                     beta_nlam_500_zero  = fitSGL_zero_alpha$beta[, 500],
                     beta_nlam_1000_zero = fitSGL_zero_alpha$beta[, 1000],
                     beta_nlam_2000_zero = fitSGL_zero_alpha$beta[, 2000],
                     beta_nlam_200_one   = fitSGL_one_alpha$beta[, 200],
                     beta_nlam_500_one   = fitSGL_one_alpha$beta[, 500],
                     beta_nlam_1000_one  = fitSGL_one_alpha$beta[, 1000],
                     beta_nlam_2000_one  = fitSGL_one_alpha$beta[, 2000],
                     beta_nlam_200_half  = fitSGL_half_alpha$beta[, 200],
                     beta_nlam_500_half  = fitSGL_half_alpha$beta[, 500],
                     beta_nlam_1000_half = fitSGL_half_alpha$beta[, 1000],
                     beta_nlam_2000_half = fitSGL_half_alpha$beta[, 2000])
```
I'm getting sleepy, so I'll append the results after I wake up.
In short, alpha slides you between Lasso and Group Lasso.
Put alpha somewhere in the middle and lambda's strength selects groups, while within each selected group the genuinely important features keep nonzero coefficients and the rest go to 0.
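To see why an intermediate alpha gives sparsity both across and within groups, here is a minimal NumPy sketch of the per-group proximal step implied by the penalty in the paper: first soft-threshold each coefficient at α·λ, then shrink the whole group at (1−α)·λ·√p_l. This is not the {SGL} package's actual internals, and `sgl_group_prox` is a hypothetical helper name.

```python
import numpy as np

def sgl_group_prox(beta, lam, alpha):
    """Proximal step for one group under the SGL penalty
    alpha*lam*||b||_1 + (1-alpha)*lam*sqrt(p)*||b||_2."""
    # 1) elementwise soft-thresholding: the lasso part, scaled by alpha
    s = np.sign(beta) * np.maximum(np.abs(beta) - alpha * lam, 0.0)
    # 2) group-level shrinkage: the group-lasso part, scaled by (1 - alpha)
    t = (1.0 - alpha) * lam * np.sqrt(len(beta))
    norm = np.linalg.norm(s)
    if norm <= t:
        return np.zeros_like(beta)   # the whole group is zeroed out
    return (1.0 - t / norm) * s      # group survives, shrunk toward 0

b = np.array([3.0, 0.5, -0.1, 2.0])  # one group's hypothetical raw coefficients
print(sgl_group_prox(b, 0.6, 1.0))   # alpha = 1: individual zeros, like plain lasso
print(sgl_group_prox(b, 0.6, 0.0))   # alpha = 0: the group is kept or dropped as a whole
print(sgl_group_prox(b, 0.6, 0.5))   # alpha = 0.5: zeros inside a surviving group
```

With alpha = 0 a weak group (e.g. `[0.2, -0.1]`) gets wiped out entirely, which is exactly the all-or-nothing behavior the group columns show below.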
One caveat: the {SGL} package also ships cv.SGL and friends, but when I try to predict from a cross-validated fit it fails, which is slightly mysterious behavior.
Most likely I just don't understand how its returned list is meant to be used; since my own goal is an explanatory model, I'll let CV slide for now.
Now that I'm up, here are the results.
```r
result %>% dplyr::select(X_name, group, contains('zero'))
   X_name group beta_nlam_200_zero beta_nlam_500_zero beta_nlam_1000_zero beta_nlam_2000_zero
1    X.1  1   -0.919070437  -0.91223722  -0.890147077  -0.82694598
2    X.2  1   -1.930531437  -2.19526196  -2.476415069  -2.93523804
3    X.3  1  -12.317403585 -13.23502437 -14.136662526 -15.47656029
4    X.4  1   -0.069421890   0.01577950   0.114029993   0.28688744
5    X.5  1   -1.432918096  -1.51957554  -1.602283187  -1.72154480
6    X.6  1    0.098980538   0.05132270   0.003071861  -0.06760882
7    X.7  1    1.659840314   1.73261807   1.799260229   1.88947418
8    X.8  1    0.903728861   0.97595933   1.049261612   1.16242585
9    X.9  1    0.405605859   0.55254002   0.718471323   1.00678492
10   X.10 1    0.217522682   0.20660601   0.187691813   0.14238394
11   X.11 1   -0.042543063   0.02625380   0.102333878   0.22922301
12   X.12 1   -0.084837994  -0.13775980  -0.197679733  -0.30108252
13   X.13 2   -2.957216078  -3.42081303  -3.892311519  -4.62764424
14   X.14 2    0.325786662   0.41627693   0.519415032   0.69859071
15   X.15 2    0.064942034  -0.01403526  -0.105939783  -0.26506662
16   X.16 2   -2.464187631  -2.81228176  -3.164572871  -3.71407507
17   X.17 2   -4.689638227  -5.40708826  -6.122918675  -7.20150107
18   X.18 2   -0.457018318  -0.52695380  -0.596593342  -0.70305740
19   X.19 2    0.744436815   0.87213738   1.003763762   1.21089624
20   X.20 2   -2.185698610  -2.49315625  -2.795311993  -3.24520822
21   X.21 2    0.083014664   0.08738530   0.088033853   0.08155590
22   X.22 2    0.056260046   0.10999228   0.175066289   0.29424011
23   X.23 2   -0.543896257  -0.70182277  -0.875619359  -1.16895024
24   X.24 2   -1.796261729  -2.06432250  -2.327537291  -2.71462352
25   X.25 3   -1.109027475  -1.13076916  -1.132408024  -1.09541273
26   X.26 3   -2.099496839  -2.26297384  -2.404382448  -2.57207836
27   X.27 3    4.961659623   5.74135666   6.561114600   7.88507367
28   X.28 3   -0.564386152  -0.57442040  -0.575266404  -0.56071205
29   X.29 3   -0.354246668  -0.32785197  -0.292340583  -0.22226264
30   X.30 3   -0.265786419  -0.28180872  -0.291834989  -0.29236051
31   X.31 3   -0.236560681  -0.35603860  -0.493107403  -0.73576728
32   X.32 3   -3.362084463  -3.72466569  -4.078353708  -4.59866290
33   X.33 3    0.816394337   0.86950099   0.910237164   0.94742844
34   X.34 3    0.993080139   1.18237003   1.381068024   1.69771805
35   X.35 3    0.362955455   0.38273267   0.399396209   0.42023231
36   X.36 3    2.331085686   2.57776024   2.818342879   3.17604085
37   X.37 4   -0.214400181  -0.34282474  -0.485030976  -0.71793845
38   X.38 4   -0.290136397  -0.38180041  -0.443617296  -0.47778922
39   X.39 4   -0.567452174  -0.78532932  -0.974745864  -1.20924492
40   X.40 4    0.163307858   0.24969092   0.345564578   0.51448056
41   X.41 4   -0.201417672  -0.29715830  -0.395285559  -0.54989507
42   X.42 4    0.444342602   0.61175382   0.752005251   0.91021997
43   X.43 4   -0.143933432  -0.20472800  -0.258366254  -0.31826404
44   X.44 4    0.187931419   0.35988375   0.581557757   1.00509886
45   X.45 4    0.450043458   0.71418455   1.007201340   1.49530596
46   X.46 4    1.404411102   2.09577419   2.805750553   3.91116715
47   X.47 4    0.523466943   0.73526838   0.925618307   1.16879890
48   X.48 4   -0.328033761  -0.47946227  -0.628296270  -0.84784914
49   X.49 5    0.009395451   0.02301674   0.042425085   0.08419528
50   X.50 5    0.263082542   0.49567410   0.734026180   1.10225507
```
```r
result %>% dplyr::select(X_name, group, contains('one'))
   X_name group beta_nlam_200_one beta_nlam_500_one beta_nlam_1000_one beta_nlam_2000_one
1    X.1  1   -0.210155828  -0.39016041   -0.573232052  -0.82648639
2    X.2  1   -1.650955055  -2.01804364   -2.379829875  -2.93518858
3    X.3  1  -14.088599351 -14.49044986  -14.896715316 -15.47770626
4    X.4  1    0.000000000   0.00000000    0.000000000   0.28675074
5    X.5  1   -1.059950009  -1.24901107   -1.425890524  -1.72127099
6    X.6  1    0.000000000   0.00000000    0.000000000  -0.06778047
7    X.7  1    1.173343536   1.38532825    1.590406994   1.88915820
8    X.8  1    0.224326726   0.47838732    0.749474619   1.16201732
9    X.9  1    0.004239839   0.22331595    0.532925643   1.00665039
10   X.10 1    0.000000000   0.00000000    0.000000000   0.14191566
11   X.11 1    0.000000000   0.00000000    0.000000000   0.22895884
12   X.12 1    0.000000000   0.00000000   -0.001720018  -0.30084480
13   X.13 2   -3.496414479  -3.83873317   -4.139557920  -4.62796518
14   X.14 2    0.000000000   0.08242796    0.323029731   0.69832269
15   X.15 2    0.000000000   0.00000000    0.000000000  -0.26486919
16   X.16 2   -2.784422883  -3.04788857   -3.309957209  -3.71430873
17   X.17 2   -5.945841847  -6.30991982   -6.654541776  -7.20228229
18   X.18 2    0.000000000  -0.14016626   -0.347146673  -0.70270792
19   X.19 2    0.251420653   0.54367473    0.819606074   1.21069296
20   X.20 2   -2.320966895  -2.59190495   -2.854564218  -3.24534018
21   X.21 2    0.000000000   0.00000000    0.000000000   0.08105258
22   X.22 2    0.000000000   0.00000000    0.000000000   0.29378796
23   X.23 2   -0.144296150  -0.43963497   -0.707836624  -1.16877476
24   X.24 2   -1.713783585  -2.01100494   -2.289897812  -2.71452870
25   X.25 3   -0.690514976  -0.82392324   -0.953027264  -1.09515934
26   X.26 3   -1.976767955  -2.15285503   -2.320915379  -2.57190360
27   X.27 3    6.290702907   6.71390545    7.135755788   7.88603441
28   X.28 3    0.000000000  -0.07774605   -0.286099036  -0.56031342
29   X.29 3    0.000000000   0.00000000    0.000000000  -0.22174427
30   X.30 3    0.000000000   0.00000000   -0.044224606  -0.29196346
31   X.31 3    0.000000000  -0.06872451   -0.306951242  -0.73553077
32   X.32 3   -3.836494454  -4.06457583   -4.258222298  -4.59891608
33   X.33 3    0.268821646   0.45127310    0.653261829   0.94702598
34   X.34 3    0.461179441   0.78903853    1.137620739   1.69740740
35   X.35 3    0.000000000   0.00000000    0.147947029   0.41990880
36   X.36 3    2.406916552   2.63306670    2.822770708   3.17600633
37   X.37 4    0.000000000  -0.14959592   -0.372905407  -0.71782679
38   X.38 4   -0.206875866  -0.31432424   -0.384108249  -0.47768860
39   X.39 4   -0.786609108  -0.93279248   -1.044085430  -1.20932995
40   X.40 4    0.000000000   0.00000000    0.143428052   0.51421076
41   X.41 4    0.000000000   0.00000000   -0.173713257  -0.54956905
42   X.42 4    0.044241695   0.25796974    0.503832690   0.90982399
43   X.43 4    0.000000000   0.00000000    0.000000000  -0.31783463
44   X.44 4    0.000000000   0.12970019    0.452602846   1.00498571
45   X.45 4    0.361157846   0.64960377    0.968465960   1.49530939
46   X.46 4    2.699629400   3.03531711    3.362575526   3.91205990
47   X.47 4    0.240327949   0.44958473    0.716930588   1.16850450
48   X.48 4   -0.050930865  -0.25899369   -0.473274682  -0.84762844
49   X.49 5    0.000000000   0.00000000    0.000000000   0.08377711
50   X.50 5    0.477798540   0.63056910    0.797658128   1.10238347
```
```r
result %>% dplyr::select(X_name, group, contains('half'))
   X_name group beta_nlam_200_half beta_nlam_500_half beta_nlam_1000_half beta_nlam_2000_half
1    X.1  1   -0.62071858  -0.68592510  -7.458937e-01  -0.82671621
2    X.2  1   -1.81613604  -2.10467881  -2.419279e+00  -2.93521332
3    X.3  1  -13.15534161 -13.81572666  -1.448429e+01 -15.47713327
4    X.4  1    0.00000000   0.00000000   4.505189e-02   0.28681910
5    X.5  1   -1.25428135  -1.38202880  -1.510880e+00  -1.72140789
6    X.6  1    0.00000000   0.00000000   0.000000e+00  -0.06769530
7    X.7  1    1.46681995   1.58414322   1.698336e+00   1.88931623
8    X.8  1    0.60066605   0.74427058   9.031424e-01   1.16222161
9    X.9  1    0.20904582   0.40815474   6.236052e-01   1.00671777
10   X.10 1    0.00000000   0.00000000   3.291013e-02   0.14214974
11   X.11 1    0.00000000   0.00000000   5.192025e-04   0.22909111
12   X.12 1    0.00000000   0.00000000  -1.037752e-01  -0.30096377
13   X.13 2   -3.21288959  -3.60442529  -4.002368e+00  -4.62780470
14   X.14 2    0.12202764   0.26392054   4.222991e-01   0.69845672
15   X.15 2    0.00000000   0.00000000  -3.092604e-02  -0.26496832
16   X.16 2   -2.61210417  -2.92004099  -3.232421e+00  -3.71419195
17   X.17 2   -5.25251744  -5.80951880  -6.370108e+00  -7.20189168
18   X.18 2   -0.22761540  -0.34978747  -4.809991e-01  -0.70288269
19   X.19 2    0.54802700   0.72532166   9.216449e-01   1.21079462
20   X.20 2   -2.27161813  -2.54633236  -2.828119e+00  -3.24527421
21   X.21 2    0.00000000   0.00000000   0.000000e+00   0.08130426
22   X.22 2    0.00000000   0.00000000   1.696232e-02   0.29401407
23   X.23 2   -0.36478506  -0.56698409  -7.951213e-01  -1.16886252
24   X.24 2   -1.79163500  -2.04976522  -2.309077e+00  -2.71457609
25   X.25 3   -0.97476956  -1.02116629  -1.051839e+00  -1.09528610
26   X.26 3   -2.06407823  -2.21940840  -2.370499e+00  -2.57199099
27   X.27 3    5.53571039   6.15829864   6.826187e+00   7.88555406
28   X.28 3   -0.29840729  -0.36788606  -4.453582e-01  -0.56051274
29   X.29 3   -0.06009490  -0.08937288  -1.266488e-01  -0.22200355
30   X.30 3   -0.04130928  -0.10497539  -1.701071e-01  -0.29216206
31   X.31 3   -0.05019059  -0.20577701  -3.951464e-01  -0.73564898
32   X.32 3   -3.59157651  -3.87620007  -4.160506e+00  -4.59878953
33   X.33 3    0.60190406   0.69637694   7.923339e-01   0.94722724
34   X.34 3    0.78515946   1.01238011   1.265741e+00   1.69756277
35   X.35 3    0.09410046   0.18374424   2.795800e-01   0.42007052
36   X.36 3    2.38218720   2.59638278   2.810935e+00   3.17602361
37   X.37 4   -0.14194709  -0.27645814  -4.400724e-01  -0.71788262
38   X.38 4   -0.28305880  -0.35714344  -4.217548e-01  -0.47773895
39   X.39 4   -0.68132091  -0.84978003  -1.004760e+00  -1.20928744
40   X.40 4    0.03022824   0.12795054   2.628507e-01   0.51434564
41   X.41 4   -0.09162862  -0.17734826  -2.917955e-01  -0.54973218
42   X.42 4    0.35231196   0.49767247   6.510938e-01   0.91002208
43   X.43 4    0.00000000  -0.04029963  -1.341940e-01  -0.31804938
44   X.44 4    0.07545063   0.26477989   5.228015e-01   1.00504226
45   X.45 4    0.44775922   0.69952400   9.988958e-01   1.49530767
46   X.46 4    1.90573821   2.47731859   3.052888e+00   3.91161352
47   X.47 4    0.46320515   0.64506385   8.493478e-01   1.16865175
48   X.48 4   -0.27004878  -0.40230047  -5.592959e-01  -0.84773886
49   X.49 5    0.00000000   0.00000000   0.000000e+00   0.08398628
50   X.50 5    0.34958263   0.55116199   7.628552e-01   1.10231930
```
Even with alpha set to 0.5 (what feels like the exact midpoint), the results look biased toward plain Lasso.
It will depend on the data structure, but to bring out more of the Group Lasso character, alpha probably needs to be somewhat smaller.
……Come to think of it, Hatena Blog supports math notation, right? Why aren't my equations rendering?