§ Small Haskell MCMC implementation
We create a simple monad called PL
which supports a single operation: sampling
from the uniform distribution on the unit interval. We then use this to implement MCMC with Metropolis–Hastings,
which can sample from arbitrary (possibly unnormalised) distributions. As a bonus, there is a small library for rendering sparklines
in the CLI.
Ideas for next time:
- Using `Applicative` to speed up computations by exploiting parallelism
- Conditioning a distribution with respect to a variable
§ Source code
import System.Random
import Data.List(sort, nub)
import Data.Proxy
import Control.Monad (replicateM)
import qualified Data.Map as M
-- | Run a Kleisli arrow @n@ times, feeding each result into the next call:
-- @mLoop f 3 a == f a >>= f >>= f@.
--
-- A non-positive count performs no steps and returns @a@ unchanged.
mLoop :: Monad m
  => (a -> m a) -- ^ step to repeat
  -> Int        -- ^ number of repetitions
  -> a          -- ^ initial value
  -> m a
mLoop f n a
  -- Guard on @<= 0@ (not just @0@) so that negative counts terminate.
  | n <= 0 = return a
  | otherwise = f a >>= mLoop f (n - 1)
-- | Sparkline glyphs ordered from the lowest level to the highest. Level 0
-- is a plain ASCII underscore, followed by the eight Unicode block elements
-- from one-eighth height up to the full block.
sparkchars :: String
sparkchars = '_' : ['▁' .. '█']
-- | Map a value to a sparkline glyph, scaled linearly against @maxv@, so that
-- @curv == maxv@ selects the last glyph of 'sparkchars'.
--
-- The computed level is clamped into the valid index range. As a result:
--
-- * negative values (and NaN for floating types) pick a valid glyph instead
--   of making '!!' fail;
-- * a non-positive @maxv@ (for example an all-zero series) yields the first
--   glyph instead of dividing by zero, which throws for 'Rational'.
num2spark :: RealFrac a
  => a -- ^ top of the scale (normally the series maximum)
  -> a -- ^ value to render
  -> Char
num2spark maxv curv = sparkchars !! clamp level
  where
    top = length sparkchars - 1
    level
      | maxv <= 0 = 0
      | otherwise = floor (curv / maxv * fromIntegral top)
    clamp = max 0 . min top
-- | Render a series as a sparkline, scaling every value against the largest
-- value in the series. An empty series renders as the empty string.
series2spark :: RealFrac a => [a] -> String
series2spark vs = case vs of
  [] -> []
  _ -> map (num2spark (maximum vs)) vs
-- | Print a series as one sparkline, followed by a newline.
seriesPrintSpark :: RealFrac a => [a] -> IO ()
seriesPrintSpark vs = putStrLn (series2spark vs)
-- | Short alias for 'Float'.
-- NOTE(review): not referenced anywhere in the visible code; possibly a
-- leftover that can be removed.
type F = Float
-- | A weight: a probability, or a possibly unnormalised density value.
-- Arithmetic is inherited from the wrapped 'Float' via newtype deriving.
-- The type itself does not enforce non-negativity.
newtype P = P
  { unP :: Float
  }
  deriving (Num)
-- | A distribution over @a@, represented by its (possibly unnormalised)
-- density or mass function.
newtype D a = D
  { runD :: a -> P
  }
-- | Uniform distribution over @n@ outcomes: every value gets weight @1 / n@,
-- whatever the value is. With @n == 0@ the weight is infinite.
uniform :: Int -> D a
uniform n = D (const (P (recip (fromIntegral n))))
-- | Infix form of 'cofmap': pre-compose a function onto the input side of a
-- contravariant value.
(>$<) :: Contravariant f => (b -> a) -> f a -> f b
f >$< fa = cofmap f fa
-- | Pull a distribution back along @f@: the new density at @b@ is the old
-- density at @f b@. No Jacobian correction is applied.
instance Contravariant D where
  cofmap f d = D (\b -> runD d (f b))
-- | Unnormalised Gaussian density centred at @mu@: @exp (-(x - mu)^2)@,
-- which corresponds to variance 1/2.
--
-- The normalising constant is omitted. The Metropolis step only looks at
-- ratios of densities, so the constant cancels there.
normalD :: Float -> D Float
normalD mu = D density
  where
    density x =
      let offset = x - mu
       in P (exp (negate (offset ^ (2 :: Int))))
-- | Density proportional to @x ** p@ on the interval [1, 2] and zero
-- elsewhere, normalised to integrate to 1 over that interval.
--
-- The normaliser is the integral of @x ** p@ over [1, 2]:
--
-- * in general it is @(2 ** (p + 1) - 1) / (p + 1)@;
-- * when @p == -1@ it is @log 2@. There the general formula degenerates to
--   0/0, which would make every in-support density NaN.
polyD :: Float -> D Float
polyD p = D $ \f -> P $
  if 1 <= f && f <= 2
    then density f
    else 0
  where
    density x
      | p == -1 = 1 / (x * log 2)
      | otherwise = (x ** p) * (p + 1) / (2 ** (p + 1) - 1)
-- | Type constructors that map "backwards": a function @x -> y@ turns an
-- @f y@ into an @f x@.
class Contravariant f where
  cofmap :: (x -> y) -> f y -> f x
-- | A probabilistic program that eventually produces a value of type @next@.
--
-- * 'Ret' finishes with a value.
-- * 'Sample01' requests one uniform draw from the unit interval (supplied by
--   the interpreter) and continues with the rest of the program given that
--   draw.
data PL next
  = Ret next
  | Sample01 (Float -> PL next)
-- | Sequencing of probabilistic programs.
--
-- * A finished program feeds its result to the continuation.
-- * A pending sample keeps waiting for its draw, with the continuation pushed
--   underneath it.
--
-- 'return' is deliberately not defined here, so it keeps its default
-- @return = pure@. This is the canonical form that GHC's
-- @-Wnoncanonical-monad-instances@ asks for.
instance Monad PL where
  Ret a >>= k = k a
  Sample01 cont >>= k = Sample01 (\u -> cont u >>= k)

-- | 'pure' is the primitive ('Ret'). '<*>' runs the function-producing
-- program first and then the argument-producing one, i.e. it behaves like
-- 'Control.Monad.ap'.
instance Applicative PL where
  pure = Ret
  ff <*> fx = do
    f <- ff
    x <- fx
    return (f x)

-- | Map over the eventual result of a program. This is defined through the
-- monad, i.e. it is 'Control.Monad.liftM'.
instance Functor PL where
  fmap f plx = plx >>= (return . f)
-- | Request a single uniform draw and return it unchanged.
sample01 :: PL Float
sample01 = Sample01 (\u -> Ret u)
-- | One Metropolis step:
--
-- 1. propose a candidate with @q@;
-- 2. draw @u@ from 'sample01';
-- 3. keep the candidate when @u <= f candidate / f a@, otherwise keep @a@.
--
-- A NaN ratio (for example 0/0) compares false, so the candidate is then
-- rejected.
--
-- The ratio has no proposal-density correction. It is therefore only a valid
-- Metropolis–Hastings step when @q@ is symmetric, or when @f@ has already
-- been divided by the proposal density.
mhStep :: (a -> Float) -- ^ unnormalised target score
  -> (a -> PL a)       -- ^ proposal, given the current state
  -> a                 -- ^ current state
  -> PL a
mhStep f q a =
  q a >>= \candidate ->
    sample01 >>= \u ->
      return (if u <= f candidate / f a then candidate else a)
-- | Types that the independence sampler in 'mhD' can explore.
class MCMC a where
  -- | Starting state of every chain run by 'mhD'.
  arbitrary :: a
  -- | Turn a uniform draw from the unit interval into a proposal.
  uniform2val :: Float -> a
  -- | Density, up to a constant factor, of the proposals that 'uniform2val'
  -- produces from uniform draws.
  --
  -- 'mhD' divides the target by this density to correct the acceptance
  -- ratio. The constant default means "proposals are already flat" and
  -- reproduces the uncorrected behaviour for instances that do not
  -- override it.
  proposalDensity :: a -> Float
  proposalDensity _ = 1

-- | Float proposals are @tan (pi * (u - 1/2))@ for uniform @u@, i.e. standard
-- Cauchy draws. Their density is @1 / (pi * (1 + x^2))@.
instance MCMC Float where
  arbitrary = 0
  uniform2val v = tan (-pi/2 + pi * v)
  proposalDensity x = 1 / (pi * (1 + x * x))

-- | Run a Metropolis chain for a fixed 100 steps starting from @a@, and
-- return only the final state.
mh :: (a -> Float)
  -> (a -> PL a)
  -> a
  -> PL a
mh f q a = mLoop (mhStep f q) 100 a

-- | Sample from a (possibly unnormalised) density using an independence
-- Metropolis–Hastings sampler. Every proposal is a fresh 'uniform2val' draw
-- that ignores the current state.
--
-- For an independence sampler, the acceptance ratio must be
-- @f(x') q(x) / (f(x) q(x'))@, where @q@ is the proposal density. 'mhStep'
-- only computes @w(x') / w(x)@, so it is given @w = f / q@. Passing @f@
-- alone would make the chain converge to @f * q@ instead of @f@.
--
-- NOTE: if no proposal lands where the target is positive within the 100
-- steps, the chain stays at 'arbitrary'. For Float that is 0, which lies
-- outside the support of e.g. @polyD@.
mhD :: MCMC a => D a -> PL a
mhD (D d) =
  let
    scorer x = unP (d x) / proposalDensity x
    proposal _ = uniform2val <$> sample01
   in mh scorer proposal arbitrary
-- | Interpret a probabilistic program with a random generator.
--
-- Each 'Sample01' request is answered with a fresh 'random' 'Float'. The
-- result is returned together with the final generator.
sample :: RandomGen g => g -> PL a -> (a, g)
sample g pl = case pl of
  Ret a -> (a, g)
  Sample01 k ->
    let (u, g') = random g
     in sample g' (k u)
-- | Run a program @n@ times, threading the generator from one run to the
-- next. Returns the results in order, together with the final generator.
--
-- A non-positive count yields no samples. A plain @0@ base case would
-- recurse forever on a negative @n@.
samples :: RandomGen g => Int -> g -> PL a -> ([a], g)
samples n g pl
  | n <= 0 = ([], g)
  | otherwise =
      let (a, g') = sample g pl
          (as, g'') = samples (n - 1) g' pl
       in (a : as, g'')
-- | Fraction of the elements of @as@ that are equal to @a@, in [0, 1].
-- An empty list gives 0 instead of NaN (0/0).
occurFrac :: (Eq a) => [a] -> a -> Float
occurFrac [] _ = 0
occurFrac as a =
  let hits = length (filter (== a) as)
   in fromIntegral hits / fromIntegral (length as)
-- | Estimate a distribution empirically: draw @n@ samples from @pl@ and use
-- each value's relative frequency among them as its probability.
--
-- Values are only compared for equality, so no 'Num' constraint is needed.
-- Each density lookup scans the whole sample list.
distribution :: (Eq a, RandomGen g) => Int -> g -> PL a -> (D a, g)
distribution n g pl =
  let (as, g') = samples n g pl
   in (D (P . occurFrac as), g')
-- | Biased coin. It returns 1 when the uniform draw is below @p1@ (so with
-- probability @p1@, assuming uniform draws), and 0 otherwise.
coin :: Float -> PL Int
coin p1 = fmap (\u -> if u < p1 then 1 else 0) sample01
-- | Count how many samples fall into each of @nbuckets@ equal-width buckets
-- spanning [minimum, maximum] of the input. Counts are returned lowest
-- bucket first.
--
-- * Buckets are measured from the minimum, and the maximum goes into the
--   last bucket, so exactly @nbuckets@ counts are produced.
-- * Empty buckets are reported as 0 rather than dropped, so gaps in the data
--   stay visible.
-- * An empty input gives all-zero counts, and a non-positive @nbuckets@
--   gives @[]@.
-- * If all samples are equal, they all land in the first bucket.
histogram :: Int
  -> [Float]
  -> [Int]
histogram nbuckets as
  | nbuckets <= 0 = []
  | null as = replicate nbuckets 0
  | otherwise = [M.findWithDefault 0 i counts | i <- [0 .. nbuckets - 1]]
  where
    minv, maxv, width :: Float
    minv = minimum as
    maxv = maximum as
    width = (maxv - minv) / fromIntegral nbuckets

    -- Clamp the index. The maximum maps to index @nbuckets@, and rounding
    -- can overshoot slightly; both must stay inside the last bucket.
    bucket :: Float -> Int
    bucket v
      | width <= 0 = 0
      | otherwise = max 0 (min (nbuckets - 1) (floor ((v - minv) / width)))

    counts :: M.Map Int Int
    counts = M.fromListWith (+) [(bucket v, 1) | v <- as]
-- | Print a heading line, then the samples rendered as a sparkline scaled
-- against the largest sample.
--
-- Only 'Real' is required. 'Ord' and 'Eq' are superclasses of it, and no
-- 'Show' instance is used.
printSamples :: Real a => String -> [a] -> IO ()
printSamples s as = do
  putStrLn $ "***" <> s
  putStrLn $ " samples: " <> series2spark (map toRational as)
-- | Print a 10-bucket histogram of the samples as a sparkline.
printHistogram :: [Float] -> IO ()
printHistogram xs =
  let counts = histogram 10 xs
      heights = map fromIntegral counts :: [Double]
   in putStrLn (series2spark heights)
-- | Toss a coin with the given bias 100 times, using the fixed seed 1, and
-- print the tosses as a sparkline.
printCoin :: Float -> IO ()
printCoin bias =
  let gen = mkStdGen 1
      tosses = fst (samples 100 gen (coin bias))
   in printSamples ("bias: " <> show bias) tosses
-- | Approximately normal samples via the central limit theorem: the number
-- of ones among five fair coin flips, as a 'Float'.
normal :: PL Float
normal = do
  flips <- replicateM 5 (coin 0.5)
  return (fromIntegral (sum flips))
-- | Demo driver:
--
-- * biased coin sparklines;
-- * a histogram of CLT samples;
-- * histograms of MCMC samples from 'normalD' and 'polyD'.
--
-- Every run uses the same seed.
main :: IO ()
main = do
  mapM_ printCoin [0.01, 0.99, 0.5, 0.7]
  let gen = mkStdGen 1
      draw :: PL Float -> [Float]
      draw pl = fst (samples 1000 gen pl)
  putStrLn "normal distribution using central limit theorem: "
  printHistogram (draw normal)
  putStrLn "normal distribution using MCMC: "
  printHistogram (draw (mhD (normalD 0.5)))
  putStrLn "sampling from x^4 with finite support"
  printHistogram (draw (mhD (polyD 4)))
§ Output
***bias: 1.0e-2
samples: ________________________________________█_█________
***bias: 0.99
samples: ███████████████████████████████████████████████████
***bias: 0.5
samples: __█____█__███_███_█__█_█___█_█_██___████████__█_███
***bias: 0.7
samples: __█__█_█__███_█████__███_█_█_█_██_█_████████__█████
normal distribution using central limit theorem:
_▄▇█▄_
normal distribution using MCMC:
__▁▄█▅▂▁___
sampling from x^4 with finite support
▁▁▃▃▃▄▅▆▇█_