HR(人事)にいるデータサイエンスパパ

HR(人事)データを分析するデータサイエンティスト。日々のデータ周りのことを書きます!

juliaで単回帰を実装してみた

はじめに

juliaで単回帰をフルスクラッチで実装してみました。
juliaの理解促進のために勢いで書いたので、間違いがあると思います笑

間違いや改善点があったら教えていただけると嬉しすぎます。


実装

以下のようにmoduleで実装しました。

module SelfSingleLinerRegression
    
using Printf
using Distributions
    
mutable struct SelfSingleLinerRegressionModel
    ETA::Float64
    cnt::Int64
    loss::Float64 #lossは0
    theta_0::Float64
    theta_1::Float64
    X
    y
end
    
function fit(model::SelfSingleLinerRegressionModel)
    
    for i in 1:model.cnt
        
        model.theta_0 -= model.ETA*sum((model.theta_0 .+ model.X.*model.theta_1) - model.y)
        model.theta_1 -= model.ETA*sum(((model.theta_0 .+ model.X.*model.theta_1) - model.y).*model.X)
        
        if i % 200 == 0
            model.loss = 0.5*sum((model.y .- (model.theta_0 .+ model.X.*model.theta_1)).^2)
            __print(i, model.theta_0, model.theta_1, model.loss)
        end
        
    end
    
end
    
function predict(model::SelfSingleLinerRegressionModel, X::Array)
    return model.theta_0 .+ X.*model.theta_1
end
            
function __print(cnt, theta_0, theta_1, loss)
    @printf("cnt: %.0f, theta_0: %.2f, theta_1: %.2f, loss: %.2f\n", cnt, theta_0, theta_1, loss)
end
    
end

方程式で解いてもいいのですが、今回は勾配降下法を用いました。
最急降下法で実装しているので、繊細なハイパーパラメータの調整が必要となっています笑


実際のデータに当てはめてみます。
学習率:0.001/学習の繰り返し回数:10000にします。

X = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [20.2, 24.2, 32.4, 36.4, 40.1]

reg = SelfSingleLinerRegression.SelfSingleLinerRegressionModel(0.001, 10000, 0, rand(), rand(), X, y)
SelfSingleLinerRegression.fit(reg)

出力の最終行は以下になりました。

cnt: 10000, theta_0: 15.06, theta_1: 5.20, loss: 2.92


確認のためGLMライブラリの結果と照合してみます。

using DataFrames, GLM

df = DataFrame(X=X, y=y)
ols = lm(@formula(y ~ X), df)

以下が結果です。

StatsModels.DataFrameRegressionModel{LinearModel{LmResp{Array{Float64,1}},DensePredChol{Float64,LinearAlgebra.Cholesky{Float64,Array{Float64,2}}}},Array{Float64,2}}

Formula: y ~ 1 + X

Coefficients:
             Estimate Std.Error t value Pr(>|t|)
(Intercept)     15.06   1.46233 10.2987   0.0020
X                 5.2  0.440908 11.7938   0.0013

結果が一致しました!


まとめ

いくつかのデータセットで試してみましたが、収束しないことが多々あったので、データによっては全く使えないコードです汗
最急降下法じゃない実装も今後やっていきたいと思います。


juliaの書き方などの面でなにか参考になるものがあれば嬉しいです笑