module Control.Monad.Fix (
        MonadFix(mfix),
        fix
  ) where
import Data.Either
import Data.Function ( fix )
import Data.Maybe
import Data.Monoid ( Dual(..), Sum(..), Product(..)
                   , First(..), Last(..), Alt(..), Ap(..) )
import Data.Ord ( Down(..) )
import GHC.Base ( Monad, NonEmpty(..), errorWithoutStackTrace, (.) )
import GHC.Generics
import GHC.List ( head, tail )
import Control.Monad.ST.Imp
import System.IO
-- | Monads that support value recursion, i.e. tying a knot through a
-- monadic computation.
--
-- @'mfix' f@ runs @f@ with the computation's own eventual result as its
-- argument. That argument is only available lazily. If @f@ forces it while
-- producing its effects, the computation typically diverges or fails with
-- an error, depending on the instance.
--
-- Instances are expected to obey the 'MonadFix' laws documented in the
-- @base@ library (purity, left shrinking, sliding and nesting).
class Monad m => MonadFix m where
    -- | The fixed point of a monadic function.
    mfix :: (a -> m a) -> m a
-- | Ties the knot through 'Just': @f@ receives the payload of its own
-- result. If the result is 'Nothing' and that payload is demanded,
-- evaluation fails with an error.
instance MonadFix Maybe where
    mfix f = result
      where
        result = f (maybe (errorWithoutStackTrace "mfix Maybe: Nothing")
                          (\v -> v)
                          result)
-- | Element-wise knot tying. The first element is found by feeding the head
-- of @f@'s output back into @f@. The remaining elements come recursively from
-- dropping one element from each of @f@'s outputs. An empty first output
-- gives an empty result.
instance MonadFix [] where
    mfix f =
        let firstPass = fix (f . head)
        in case firstPass of
               (y : _) -> y : mfix (tail . f)
               []      -> []
-- | Knot tying for non-empty lists. The head is the fixed point through the
-- head of @f@'s output, and the tail is built recursively from the tails of
-- @f@'s outputs. Every match is lazy, so the result's outer ':|' is
-- available before anything inside it is evaluated.
instance MonadFix NonEmpty where
  mfix f = firstElem :| mfix (dropFirst . f)
    where
      ~(firstElem :| _) = fix (f . takeFirst)
      takeFirst ~(a :| _) = a
      dropFirst ~(_ :| rest) = rest
-- | Delegates directly to 'fixIO'.
instance MonadFix IO where
    mfix f = fixIO f
-- | Reader-style knot tying: the environment @r@ is passed through
-- unchanged, and the value @f@ produces is fed back to it as its argument.
instance MonadFix ((->) r) where
    mfix f r = knot
      where
        knot = f knot r
-- | Ties the knot through 'Right': @f@ receives the payload of its own
-- result. If the result is a 'Left' and that payload is demanded,
-- evaluation fails with an error.
instance MonadFix (Either e) where
    mfix f = result
      where
        result = f (either (\_ -> errorWithoutStackTrace "mfix Either: Left")
                           (\v -> v)
                           result)
-- | Delegates directly to 'fixST'.
instance MonadFix (ST s) where
    mfix f = fixST f
-- | 'Dual' is a plain wrapper: unwrap @f@'s output and feed it back.
instance MonadFix Dual where
    mfix f = Dual knot
      where
        knot = getDual (f knot)
-- | 'Sum' is a plain wrapper: unwrap @f@'s output and feed it back.
instance MonadFix Sum where
    mfix f = Sum knot
      where
        knot = getSum (f knot)
-- | 'Product' is a plain wrapper: unwrap @f@'s output and feed it back.
instance MonadFix Product where
    mfix f = Product knot
      where
        knot = getProduct (f knot)
-- | Reuses the 'Maybe' instance on the wrapped value.
instance MonadFix First where
    mfix f = First (mfix step)
      where
        step x = getFirst (f x)
-- | Reuses the 'Maybe' instance on the wrapped value.
instance MonadFix Last where
    mfix f = Last (mfix step)
      where
        step x = getLast (f x)
-- | Reuses the underlying functor's instance on the wrapped computation.
instance MonadFix f => MonadFix (Alt f) where
    mfix f = Alt (mfix step)
      where
        step x = getAlt (f x)
-- | Reuses the underlying functor's instance on the wrapped computation.
instance MonadFix f => MonadFix (Ap f) where
    mfix f = Ap (mfix step)
      where
        step x = getAp (f x)
-- | 'Par1' is a plain wrapper: unwrap @f@'s output and feed it back.
instance MonadFix Par1 where
    mfix f = Par1 knot
      where
        knot = unPar1 (f knot)
-- | Reuses the underlying functor's instance on the wrapped computation.
instance MonadFix f => MonadFix (Rec1 f) where
    mfix f = Rec1 (mfix step)
      where
        step x = unRec1 (f x)
-- | Reuses the underlying functor's instance on the wrapped computation.
instance MonadFix f => MonadFix (M1 i c f) where
    mfix f = M1 (mfix step)
      where
        step x = unM1 (f x)
-- | Componentwise: each side of the product takes its own fixed point over
-- the matching projection of @f@'s result.
instance (MonadFix f, MonadFix g) => MonadFix (f :*: g) where
    mfix f = mfix (\x -> leftOf (f x)) :*: mfix (\x -> rightOf (f x))
      where
        leftOf  (l :*: _) = l
        rightOf (_ :*: r) = r
-- | 'Down' is a plain wrapper: the fixed point is taken on the wrapped value.
instance MonadFix Down where
    mfix f = Down (fix (unDown . f))
      where
        -- Named unDown rather than getDown so it does not shadow the
        -- getDown record selector that Down(..) brings in on base >= 4.14
        -- (-Wname-shadowing). A local unwrapper also keeps this instance
        -- building on older bases that lack that field.
        unDown (Down x) = x