{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- |
-- Module      : Crypto.PubKey.ElGamal
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
-- This module is a work in progress. Do not use it:
-- it might eat your dog, your data, or even both.
--
-- TODO: provide a mapping between integers and ciphertexts;
--       generate numbers correctly.
module Crypto.PubKey.ElGamal (
    Params,
    PublicNumber,
    PrivateNumber,
    EphemeralKey (..),
    SharedKey,
    Signature,

    -- * Generation
    generatePrivate,
    generatePublic,

    -- * Encryption and decryption with no scheme
    encryptWith,
    encrypt,
    decrypt,

    -- * Signature primitives
    signWith,
    sign,

    -- * Verification primitives
    verify,
) where

import Crypto.Hash
import Crypto.Internal.ByteArray (ByteArrayAccess)
import Crypto.Internal.Imports
import Crypto.Number.Basic (gcde)
import Crypto.Number.Generate (generateMax)
import Crypto.Number.ModArithmetic (expFast, expSafe, inverse)
import Crypto.Number.Serialize (os2ip)
import Crypto.PubKey.DH (
    Params (..),
    PrivateNumber (..),
    PublicNumber (..),
    SharedKey (..),
 )
import Crypto.Random.Types
import Data.Maybe (fromJust)

-- | ElGamal signature, stored as the pair @(r, s)@ in that order:
-- this is the layout built by 'signWith' and destructured by 'verify'.
data Signature = Signature (Integer, Integer)

-- | Ephemeral (a.k.a. temporary) ElGamal key.
--
-- This is the secret exponent consumed by 'encryptWith'; it must not be
-- reused from one encryption to the next.
newtype EphemeralKey = EphemeralKey Integer
    deriving (NFData)

-- | Draw a random private number (usually called @a@) with no specific
-- property.
--
-- The argument @q@ is meant to be the order of the group G; the number is
-- whatever 'generateMax' produces for that bound, wrapped in 'PrivateNumber'.
generatePrivate :: MonadRandom m => Integer -> m PrivateNumber
generatePrivate q = fmap PrivateNumber (generateMax q)

-- | Draw a random ephemeral key with no specific property.
--
-- The key is generated exactly like a private number for the bound @q@
-- (meant to be the order of the group G) and then re-tagged as an
-- 'EphemeralKey'.
generateEphemeral :: MonadRandom m => Integer -> m EphemeralKey
generateEphemeral q = fmap retag (generatePrivate q)
  where
    -- Swap the newtype wrapper, keeping the underlying integer untouched.
    retag (PrivateNumber secret) = EphemeralKey secret

-- | Derive the public number handed to the other party.
--
-- It is usually written @h = g^a@: the generator @g@ raised to the private
-- number @a@, reduced modulo @p@ (computed with 'expSafe').
generatePublic :: Params -> PrivateNumber -> PublicNumber
generatePublic (Params modulus generator _) (PrivateNumber secret) =
    PublicNumber (expSafe generator secret modulus)

-- | Encrypt the integer @m@ using a caller-supplied ephemeral key.
--
-- The result is the pair @(c1, c2)@ with @c1 = g^b mod p@ and
-- @c2 = (h^b * m) mod p@, where @b@ is the ephemeral key and @h@ the
-- recipient's public number.
--
-- Never reuse an ephemeral key.
encryptWith
    :: EphemeralKey -> Params -> PublicNumber -> Integer -> (Integer, Integer)
encryptWith (EphemeralKey b) (Params p g _) (PublicNumber h) m =
    (expSafe g b p, (sharedSecret * m) `mod` p)
  where
    -- h^b mod p: the value the recipient recomputes as c1^a mod p.
    sharedSecret = expSafe h b p

-- | Encrypt the integer @m@ for the holder of the given public number.
--
-- A fresh ephemeral key @b@ is generated for the bound @p - 1@ and then
-- passed to 'encryptWith'.
encrypt
    :: MonadRandom m => Params -> PublicNumber -> Integer -> m (Integer, Integer)
encrypt params@(Params p _ _) public m = do
    -- p is prime, hence the order of the group is p-1
    ephemeral <- generateEphemeral (p - 1)
    return (encryptWith ephemeral params public m)

-- | Decrypt the ciphertext pair @(c1, c2)@ with the private number @a@.
--
-- The shared secret @s = c1^a mod p@ is recomputed and the plaintext is
-- recovered as @(c2 * s^-1) mod p@.
--
-- @s@ is only invertible when it is coprime with @p@. For a prime @p@ that
-- holds unless @c1@ is a multiple of @p@, which a malformed or malicious
-- ciphertext can arrange. In that case evaluating the result raises an
-- 'error' with a descriptive message, because the 'Integer' return type
-- leaves no room for reporting the failure as a value.
decrypt :: Params -> PrivateNumber -> (Integer, Integer) -> Integer
decrypt (Params p _ _) (PrivateNumber a) (c1, c2) =
    case inverse s p of
        Just sInv -> (c2 * sInv) `mod` p
        Nothing ->
            error
                "Crypto.PubKey.ElGamal.decrypt: shared secret is not invertible modulo p (malformed ciphertext or non-prime modulus)"
  where
    -- Shared secret, equal to h^b mod p on the encrypting side.
    s = expSafe c1 a p

-- | Sign a message with an explicitly supplied number @k@.
--
-- If @k@ is not appropriate, no signature is returned: 'Nothing' is
-- produced unless @0 < k < p-1@ and @gcd(k, p-1) = 1@.
--
-- Even with an appropriate @k@ the signature generation can fail (the
-- computed @s@ component is 0), and 'Nothing' is returned as well. Users of
-- this function need to retry with a different @k@ value.
signWith
    :: (ByteArrayAccess msg, HashAlgorithm hash)
    => Integer
    -- ^ random number k, with 0 < k < p-1 and gcd(k,p-1)=1
    -> Params
    -- ^ DH params (p,g)
    -> PrivateNumber
    -- ^ DH private key
    -> hash
    -- ^ collision resistant hash algorithm
    -> msg
    -- ^ message to sign
    -> Maybe Signature
signWith k (Params p g _) (PrivateNumber x) hashAlg msg
    -- A non-positive k is outside the documented range; reject it up front
    -- so it never reaches the modular exponentiation as an exponent.
    | k <= 0 = Nothing
    | k >= p - 1 || d > 1 = Nothing -- k too large, or gcd(k,p-1) is not 1
    | s == 0 = Nothing
    | otherwise = Just $ Signature (r, s)
  where
    -- r = g^k mod p
    r = expSafe g k p
    -- message digest interpreted as an integer
    h = os2ip $ hashWith hashAlg msg
    -- s = (h - x*r) * k^-1 mod (p-1)
    s = ((h - x * r) * kInv) `mod` (p - 1)
    -- extended gcd: kInv is the Bezout coefficient of k, d is gcd(k, p-1)
    (kInv, _, d) = gcde k (p - 1)

-- | Sign a message.
--
-- A random number @k@ is generated for the bound @p - 1@ and handed to
-- 'signWith'. Since that attempt may yield no signature, the function
-- retries with a fresh @k@ until a signature is produced.
sign
    :: (ByteArrayAccess msg, HashAlgorithm hash, MonadRandom m)
    => Params
    -- ^ DH params (p,g)
    -> PrivateNumber
    -- ^ DH private key
    -> hash
    -- ^ collision resistant hash algorithm
    -> msg
    -- ^ message to sign
    -> m Signature
sign params@(Params p _ _) priv hashAlg msg =
    generateMax (p - 1) >>= attempt
  where
    -- Try one candidate k; on failure start over with a new random draw.
    attempt k =
        maybe
            (sign params priv hashAlg msg)
            return
            (signWith k params priv hashAlg msg)

-- | Verify a signature @(r, s)@ over a message.
--
-- The signature is rejected outright unless @0 < r < p@ and
-- @0 < s < p-1@. Otherwise it is accepted exactly when
-- @g^h mod p == (y^r * r^s) mod p@, where @h@ is the message digest
-- interpreted as an integer and @y@ the signer's public number.
verify
    :: (ByteArrayAccess msg, HashAlgorithm hash)
    => Params
    -> PublicNumber
    -> hash
    -> msg
    -> Signature
    -> Bool
verify (Params p g _) (PublicNumber y) hashAlg msg (Signature (r, s)) =
    rInRange && sInRange && expected == actual
  where
    rInRange = r > 0 && r < p
    sInRange = s > 0 && s < p - 1
    -- message digest interpreted as an integer
    digest = os2ip (hashWith hashAlg msg)
    expected = expFast g digest p
    actual = (expFast y r p * expFast r s p) `mod` p