Monday, May 15, 2017

128-bit AES in electronic codebook (ECB) mode

Rijndael is an algorithm that might actually be more succinct in C than in Haskell, but I've always wanted to learn the details of AES, and writing it in C wouldn't be nearly as much fun. My thanks to Jeff Moser and Sam Trenholme for their excellent elucidations.

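Note that this implementation supports only ECB mode (each 16-byte block is encrypted independently). It includes no decryption code and computes the S-box rather than hard-coding it. It is also probably vulnerable to side-channel attacks, since its S-box lookups and its multiply-by-2 branch depend on secret data. So of course it's neither intended nor safe for production use.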

{-# LANGUAGE NoMonomorphismRestriction #-}

import Control.Applicative (liftA2)
import Data.Bits (xor, shiftL, shiftR, (.|.), (.&.))
import Data.List (transpose, sortBy, foldl')
import Data.Ord (comparing)
import Data.Word (Word8)

-- | Encrypt one 16-byte block with AES-128 (no chaining, i.e. ECB).
--
-- The block and the key are both 4x4 byte matrices given as a list of rows.
-- Following FIPS-197's state layout, byte i of a block or key belongs at
-- row (i `mod` 4), column (i `div` 4).
--
-- Round 0 is AddRoundKey alone. Rounds 1-9 are full rounds. Round 10 (the
-- outermost expression) skips MixColumns.
encrypt :: [[Word8]] -> [[Word8]] -> [[Word8]]
encrypt input key = addRoundKey finalKey (sRows (subBytes afterNine))
 where
  roundKeys = expand key                 -- 11 keys: rounds 0 .. 10
  firstKey  = head roundKeys
  finalKey  = last roundKeys
  innerKeys = init (tail roundKeys)      -- keys for rounds 1 .. 9

  whitened  = addRoundKey input firstKey
  afterNine = foldl' fullRound whitened innerKeys

  fullRound st rk = addRoundKey (mixColumns (sRows (subBytes st))) rk

  addRoundKey = zipWith (zipWith xor)
  subBytes    = map (map sub)
  -- 'mix' acts on a single column, so flip to columns and back
  mixColumns  = transpose . map mix . transpose

-- | MixColumns for a single state column over GF(2^8): multiplies the
-- column by the circulant matrix
--
-- > [02 03 01 01]
-- > [01 02 03 01]
-- > [01 01 02 03]
-- > [03 01 01 02]
--
-- Only four-element columns are accepted.
mix [a, b, c, d] = [term a b c d, term b c d a, term c d a b, term d a b c]
 where
  -- 2·p ⊕ 3·q ⊕ r ⊕ s, where 3·q is computed as 2·q ⊕ q
  term p q r s = fg p `xor` fg q `xor` q `xor` r `xor` s

-- | "xtime": multiply a byte by 2 in GF(2^8) modulo the AES polynomial
-- x^8 + x^4 + x^3 + x + 1.
--
-- When bit 7 was set before the shift, the overflow is reduced by XOR-ing
-- in 0x1B. Intended for 'Word8', where the left shift discards bit 7.
fg b
  | b .&. 0x80 == 0x80 = doubled `xor` 0x1B
  | otherwise          = doubled
 where
  doubled = b `shiftL` 1

-- | ShiftRows: rotate row i of the state left by i positions.
--
-- Row 0 is returned unchanged. Rows 1-3 are rotated and cut to 4 bytes.
-- Any rows beyond the fourth are dropped, matching the 4-row AES state.
-- An empty state yields an empty state instead of a pattern-match failure.
-- Rows 1-3 must be non-empty, because 'cycle' rejects an empty list.
sRows :: [[a]] -> [[a]]
sRows []            = []
sRows (row0 : rest) = row0 : zipWith rotateLeft rest [1, 2, 3]
 where
  rotateLeft row i = take 4 (drop i (cycle row))

-----------------------------------------------------------
-- | AES-128 key schedule.
--
-- Returns the cipher key followed by ten derived round keys (11 in total).
-- Each key is a 4x4 matrix of rows, and its columns are the schedule words.
-- Key i+1 is computed from key i and the i-th round constant.
expand :: [[Word8]] -> [[[Word8]]]
expand k = scanl nextKey k roundConstants
 where
  -- Rcon: successive powers of 2 in GF(2^8) (0x1B and 0x36 are the reduced x^8, x^9)
  roundConstants = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
  nextKey prev rc =
    xpndE (transpose prev) (xpndC (xpndB (xpndA (xpnd0 rc prev))))

-- | Final key-schedule step: build the fourth new column as
-- (new column 2) XOR (last column of @prev@), then convert the four columns
-- back into rows. @prev@ is the previous round key as a list of columns.
-- The incoming fourth column (the temporary word) is discarded.
xpndE prev [w0, w1, w2, _] = transpose [w0, w1, w2, w3]
 where
  w3 = zipWith xor w2 (last prev)

-- | Key-schedule step: column 2 becomes column 1 XOR column 2; the other three columns pass through.
xpndC [w0, w1, w2, w3] = w0 : w1 : zipWith xor w1 w2 : [w3]
-- | Key-schedule step: column 1 becomes column 0 XOR column 1; the other three columns pass through.
xpndB [w0, w1, w2, w3] = w0 : zipWith xor w0 w1 : [w2, w3]
-- | Key-schedule step: column 0 becomes column 0 XOR column 3; all four columns
-- are kept. In the schedule, column 3 holds the temporary word built by 'xpnd0'.
xpndA [w0, w1, w2, tmp] = [zipWith xor w0 tmp, w1, w2, tmp]

-- | First key-schedule step for one round.
--
-- Takes the round constant @rc@ and the previous round key @ws@ (as rows).
-- Returns the previous key's first three columns, followed by a temporary
-- word: SubWord(RotWord(last column)) XOR [rc, 0, 0, 0].
xpnd0 :: Word8 -> [[Word8]] -> [[Word8]]
xpnd0 rc ws = take 3 cols ++ [temp]
 where
  cols    = transpose ws
  -- RotWord: rotate the last column up by one byte
  rotated = take 4 (drop 1 (cycle (last cols)))
  temp    = zipWith xor (map sub rotated) (rc : replicate 3 0)

----------------------------------------------------
-- | SubBytes for a single byte. Looks the byte up in the 16x16 'sbox' table,
-- using the high nibble as the row and the low nibble as the column.
sub w = row !! fromIntegral lo
 where
  (hi, lo) = nibs w
  row      = sbox !! fromIntegral hi

-- | Split a byte into its (high nibble, low nibble), each in the range 0..15.
nibs w = (hi, lo)
 where
  hi = (w .&. 0xF0) `shiftR` 4
  lo = w .&. 0x0F
-- | Bitwise XOR (addition in GF(2^8)). Written as a function binding so it stays polymorphic.
x ⊕ y = x `xor` y
-- | Conditional expression: @cond ? (whenTrue, whenFalse)@.
(?) :: Bool -> (a, a) -> a
True  ? (a, _) = a
False ? (_, b) = b
infix 2 ?

-- | Index a list of rows. The column comes first: @get rows col rowIx@ is
-- the element at row @rowIx@, column @col@.
get :: [[a]] -> Int -> Int -> a
get rows col rowIx = row !! col
 where
  row = rows !! rowIx

----------------------------------------------------
-- | The S-box as a 16x16 table: row = high nibble of the input, column = low nibble.
-- Built by sorting 'sbx''s (input, output) pairs by input byte and chunking
-- the outputs into rows of 16. This assumes 'sbx' yields exactly one pair per byte value.
sbox :: [[Word8]]
sbox = grid 16 outputs
 where
  outputs = snd <$> sortBy (comparing fst) (sbx 1 1 [])

-- | Generate S-box entries as (input byte, output byte) pairs, prepended to @ws@.
--
-- Each step:
--
-- * multiplies @p@ by 3 in GF(2^8), as p ⊕ 2p reduced by 0x1B;
-- * updates @q@ with three shift-XORs and a conditional 0x09 correction
--   (the standard "divide by 3" trick, so that, starting from p = q = 1,
--   @q@ presumably tracks p's inverse — cf. the Rijndael S-box derivation);
-- * records (p', affine(q') ⊕ 0x63).
--
-- Once 255 pairs are accumulated, the entry for 0 (which has no inverse)
-- is added as (0, 0x63).
--
-- The pair count is carried alongside the list rather than recomputed with
-- 'length' on every step, which was quadratic. The stop test is @>=@, so an
-- initial list that already holds 255 or more pairs terminates instead of
-- looping forever.
sbx :: Word8 -> Word8 -> [(Word8, Word8)] -> [(Word8, Word8)]
sbx p0 q0 ws0 = go p0 q0 (length ws0) ws0
 where
  go p q count ws
    | count >= 255 = (0, 0x63) : ws
    | otherwise    = go p' r (count + 1) ((p', xf ⊕ 0x63) : ws)
   where
    -- multiply p by 3: p ⊕ 2p, reducing by 0x1B when bit 7 overflows
    p' = p ⊕ shiftL p 1 ⊕ ((p .&. 0x80 /= 0) ? (0x1B, 0))
    -- liftA2 (.) xor shiftL acc s == acc `xor` shiftL acc s,
    -- i.e. q ^= q<<1; q ^= q<<2; q ^= q<<4
    q1 = foldl' (liftA2 (.) xor shiftL) q [1, 2, 4]
    r  = q1 ⊕ ((q1 .&. 0x80 /= 0) ? (0x09, 0))
    -- affine transformation without its 0x63 constant (added above)
    xf = r ⊕ rotl8 r 1 ⊕ rotl8 r 2 ⊕ rotl8 r 3 ⊕ rotl8 r 4

-- | Chop a list into consecutive rows of @n@ elements; the last row may be shorter.
grid :: Int -> [a] -> [[a]]
grid n xs
  | null xs   = []
  | otherwise = row : grid n rest
 where
  (row, rest) = splitAt n xs

-- | Rotate left by @n@ bits within an 8-bit word: bits shifted out of the top
-- re-enter at the bottom. This is a true 8-bit rotation only for 'Word8'
-- with 0 <= n <= 8. Callers in this file use n = 1..4.
rotl8 w n = upper .|. wrapped
 where
  upper   = w `shiftL` n
  wrapped = w `shiftR` (8 - n)