Derive an optimized case-of expression from a function that takes a closed set as its input


Question

I have a closed set of values:

data Value = A | B | C | D | E ...
  deriving (Eq, Ord, Show)

And a data structure that represents their order:

order :: [[Value]]
order = [
  [ B ],
  [ A, D ],
  [ C ],
  ...
  ]

I need to convert a Value's order into an Int. I could do it like this:

prec' :: [[Value]] -> Value -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs
  then 1 + length rest
  else prec' rest v

prec :: Value -> Int
prec = prec' order

However, this prec has complexity O(n): every call scans order group by group.

What I want is a very lightweight, optimized function like this one:

prec :: Value -> Int
prec = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0
  ...

But of course I don't want to write it manually, otherwise it risks being inconsistent with the information stored in order. The Haskell compiler should be able to derive that function on its own easily, since its input is a closed set.
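Until the function is generated, one way to at least catch such an inconsistency is a test that compares a hand-written case against prec' order on every constructor. This is a minimal sketch that assumes a concrete five-constructor Value (the real type is larger) that additionally derives Bounded and Enum; precFast and precMatchesOrder are hypothetical names.

```haskell
{-# LANGUAGE LambdaCase #-}

-- Assumption: a concrete five-constructor Value; the real type has more.
data Value = A | B | C | D | E
  deriving (Eq, Ord, Show, Bounded, Enum)

order :: [[Value]]
order = [[B], [A, D], [C]]

-- The slow reference implementation from the question.
prec' :: [[Value]] -> Value -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs
  then 1 + length rest
  else prec' rest v

-- Hypothetical hand-written fast version (the one we would like generated).
precFast :: Value -> Int
precFast = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0

-- True when the hand-written table agrees with `order` on every constructor.
precMatchesOrder :: Bool
precMatchesOrder = all (\v -> precFast v == prec' order v) [minBound .. maxBound]
```

Running precMatchesOrder in a test suite flags any edit to order that is not mirrored in the case expression; it does not remove the duplication itself.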

How can I get GHC to generate a function like the last definition of prec?


Answer 1

Score: 8

Solution 1: Use Template Haskell to generate the code you want.

Solution 2 (expanded below): (Ab)use the simplifier.

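The main obstacle to simplification is that GHC will not inline recursive functions. One workaround is to do the recursion through type classes.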

-- Intuitively unroll :: Nat -> (a -> a) -> (a -> a)
-- but the nat is now a type-level parameter.
class Unroll (n :: Nat) where
  unroll :: (a -> a) -> (a -> a)

instance Unroll 0 where
  unroll = id

instance {-# OVERLAPPABLE #-} Unroll (n-1) => Unroll n where
  unroll f = f . unroll @(n-1) f

This lets you define the following fixpoint operator that unfolds the first n iterations:

unrollfix :: forall n a. Unroll n => (a -> a) -> a
unrollfix f = unroll @n f (fix f)
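To see that unrolling changes only the shape of the code and not its meaning, here is a self-contained sketch: unrollfix @n f agrees with fix f even when the recursion goes deeper than n, because the innermost step falls back to fix. The extension list is an assumption chosen so it also compiles under Haskell2010, and lenStep is a hypothetical example step function, not part of the answer.

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, FlexibleContexts, FlexibleInstances,
             KindSignatures, ScopedTypeVariables, TypeApplications,
             TypeOperators, UndecidableInstances #-}

import Data.Function (fix)
import GHC.TypeNats

-- As in the answer: Unroll n composes f with itself n times.
class Unroll (n :: Nat) where
  unroll :: (a -> a) -> (a -> a)

instance Unroll 0 where
  unroll = id

instance {-# OVERLAPPABLE #-} Unroll (n - 1) => Unroll n where
  unroll f = f . unroll @(n - 1) f

-- The first n iterations are unrolled; anything deeper falls back to fix.
unrollfix :: forall n a. Unroll n => (a -> a) -> a
unrollfix f = unroll @n f (fix f)

-- Hypothetical example: one step of list length, written for use with fix.
lenStep :: ([b] -> Int) -> [b] -> Int
lenStep _ [] = 0
lenStep rec (_ : xs) = 1 + rec xs
```

For example, unrollfix @3 lenStep "abcdefghij" still returns 10: three steps are unrolled and the remaining seven go through fix.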

You then need to write all recursive functions using fix, and replace fix with unrollfix. You have to sprinkle some INLINE pragmas around too.

elem with fix:

elem :: forall a. Eq a => a -> [a] -> Bool
elem = fix go
  where
    go elem_ x [] = False
    go elem_ x (y : ys) = x == y || elem_ x ys

elem with unrollfix:

{-# INLINE uelem #-}
uelem :: forall n a. (Unroll n, Eq a) => a -> [a] -> Bool
uelem = unrollfix @n go
  where
    go elem_ x [] = False
    go elem_ x (y : ys) = x == y || elem_ x ys

Also length (omitted), and prec'.

prec' with fix:

prec' :: forall a. Eq a => [[a]] -> a -> Int
prec' = fix go
  where
    go prec_ [] v = 0
    go prec_ (vs : rest) v = if elem v vs
      then 1 + length rest
      else prec_ rest v

prec' with unrollfix:

prec' :: forall n a. (Unroll n, Eq a) => [[a]] -> a -> Int
prec' = unrollfix @n go
  where
    go prec_ [] v = 0
    go prec_ (vs : rest) v = if uelem @n v vs
      then 1 + ulength @n rest
      else prec_ rest v
    {-# INLINE go #-}

Finally, set the n parameter to a high enough value to enable simplification.

prec :: Value -> Int
prec v = prec' @5 order v
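One rough way to estimate a sufficient n, assuming each unrolled step of prec', uelem or ulength consumes exactly one list cell and reaching the [] case costs one more step: n must cover the number of groups plus one, and the longest group plus one. minUnroll is a hypothetical back-of-the-envelope helper, not something GHC computes.

```haskell
-- Hypothetical helper: an estimate of the unroll depth needed so that no call
-- falls through to the (non-inlinable) fix, assuming every unrolled step
-- consumes one list cell and the [] case needs one extra step.
minUnroll :: [[a]] -> Int
minUnroll groups = maximum (outer : inner)
  where
    outer = length groups + 1            -- prec' walking the outer list
    inner = map ((+ 1) . length) groups  -- uelem walking a single group
```

For the full code's order = [[A], [B, C], [D], [E]] this gives 5, which matches the @5 used here. ulength only walks the rest of the outer list, so it is already covered by the outer bound.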

Full code:

{-# LANGUAGE AllowAmbiguousTypes, DataKinds, FlexibleContexts, FlexibleInstances, KindSignatures, MultiParamTypeClasses, ScopedTypeVariables, TypeApplications, TypeOperators, UndecidableInstances #-}
{-# OPTIONS_GHC -ddump-simpl #-}
module A (Value(..), prec) where

import GHC.TypeNats
import Data.Function (fix)
import GHC.Exts

class Unroll (n :: Nat) where
  unroll :: (a -> a) -> (a -> a)

instance Unroll 0 where
  unroll = id

instance {-# OVERLAPPABLE #-} Unroll (n-1) => Unroll n where
  unroll f = f . unroll @(n-1) f

unrollfix :: forall n a. Unroll n => (a -> a) -> a
unrollfix f = unroll @n f (fix f)

data Value = A | B | C | D | E
  deriving Eq

order :: [[Value]]
order = [[A], [B, C], [D], [E]] 

{-# INLINE uelem #-}
uelem :: forall n a. (Unroll n, Eq a) => a -> [a] -> Bool
uelem = unrollfix @n go
  where
    go elem_ x [] = False
    go elem_ x (y : ys) = x == y || elem_ x ys
    {-# INLINE go #-}

{-# INLINE ulength #-}
ulength :: forall n a. Unroll n => [a] -> Int
ulength = unrollfix @n go
  where
    go length_ [] = 0
    go length_ (_ : xs) = 1 + length_ xs
    {-# INLINE go #-}

prec' :: forall n a. (Unroll n, Eq a) => [[a]] -> a -> Int
prec' = unrollfix @n go
  where
    {-# INLINE go #-}
    go prec_ [] v = 0
    go prec_ (vs : rest) v = if uelem @n v vs
      then 1 + ulength @n rest
      else prec_ rest v

prec :: Value -> Int
prec v = prec' @5 order v

Generated Core (using the -ddump-simpl option); look at the unfolding rather than the main definition:

\ (v_aQC [Occ=Once1!] :: Value) ->
                 case v_aQC of {
                   __DEFAULT -> GHC.Types.I# 3#;
                   A -> GHC.Types.I# 4#;
                   D -> GHC.Types.I# 2#;
                   E -> GHC.Types.I# 1#
                 }

Answer 2

Score: 4

I would just wrap a general-purpose memoization tool around the function, like MemoTrie, memoize or fastmemo.

{-# LANGUAGE DeriveGeneric, DeriveAnyClass #-}

import Data.Function.FastMemo
import GHC.Generics (Generic)

data Value = A | B | C | D | E ...
  deriving (Eq, Ord, Show, Generic, Memoizable)

prec :: Value -> Int
prec = memoize $ prec' order

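This may not be as optimized as a direct TH solution, but the Generic-derived Memoizable instance should result in something reasonably similar. I'm not sure which of these packages does it best.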
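If you would rather not pull in a package, the same idea can be hand-rolled for any Bounded/Enum type: tabulate the function once in a lazy array and index it with fromEnum. This is a sketch, not the API of MemoTrie, memoize or fastmemo; memoEnum is a hypothetical helper, and the example assumes Value derives Bounded and Enum.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Array (Array, listArray, (!))

-- Hypothetical helper: memoize a function on a Bounded/Enum domain.
-- The boxed array is lazy, so each f x is computed on first use, then shared.
memoEnum :: forall a b. (Bounded a, Enum a) => (a -> b) -> a -> b
memoEnum f = \x -> table ! fromEnum x
  where
    table :: Array Int b
    table = listArray (fromEnum (minBound :: a), fromEnum (maxBound :: a))
                      (map f [minBound .. maxBound])

data Value = A | B | C | D | E
  deriving (Eq, Show, Bounded, Enum)

order :: [[Value]]
order = [[B], [A, D], [C]]

prec' :: [[Value]] -> Value -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs then 1 + length rest else prec' rest v

-- A top-level binding, so the table is built once and shared by all calls.
prec :: Value -> Int
prec = memoEnum (prec' order)
```

The table lives in the where clause outside the lambda, which is what lets every call to prec reuse it.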


Answer 3

Score: 4

Define prec first, then generate order using prec.

{-# LANGUAGE LambdaCase #-}

import Data.Function (on)
import Data.List (groupBy, sortBy)

prec :: Value -> Int
prec = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0

order :: [[Value]]
order = go [A, B, C, D, E]
  where eqPrec = (==) `on` prec
        ordPrec = compare `on` prec
        go = reverse . groupBy eqPrec . sortBy ordPrec
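A self-contained variant of this approach, assuming Value also derives Bounded, Enum, Eq and Show so that the constructor list can be [minBound .. maxBound] instead of being written out by hand (a newly added constructor is then picked up automatically):

```haskell
{-# LANGUAGE LambdaCase #-}

import Data.Function (on)
import Data.List (groupBy, sortBy)

data Value = A | B | C | D | E
  deriving (Eq, Show, Bounded, Enum)

-- Single source of truth: the precedence of each constructor.
prec :: Value -> Int
prec = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0

-- Derived: group constructors with equal precedence, highest group first.
order :: [[Value]]
order = reverse . groupBy ((==) `on` prec) . sortBy (compare `on` prec) $ [minBound .. maxBound]
```

This evaluates to [[B], [A, D], [C], [E]]: sortBy is stable, so A stays before D, and E (precedence 0) ends up in its own last group.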


Answer 4

Score: 3

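Perhaps one simple solution is to do by hand what you want the compiler to do automatically once it has the case statement you describe: compute a jump table.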

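import Data.Array (Array, listArray)
import Data.Array.Base (unsafeAt)

-- deriving Enum makes a compiler-written case statement like what you want
data Foo = A | B | C | D | E deriving (Bounded, Enum)

-- orderSlow :: Foo -> Int is your existing slow function (e.g. the question's prec' order)
orderArray :: Array Int Int
orderArray = listArray
    (0, fromEnum (maxBound :: Foo))
    (orderSlow <$> [minBound..maxBound])

-- unsafeAt indexes by offset from the lower bound (0 here) without bounds checks
prec :: Foo -> Int
prec = unsafeAt orderArray . fromEnum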

This will have to run orderSlow once for each possible value, but subsequent accesses will be fast O(1) lookups.


Answer 5

Score: 2

Rewrite your prec' function using a fold, in this case foldl', and add INLINE pragmas to prec' and order.

{-# LANGUAGE BangPatterns #-}

import Data.Foldable
import Data.Maybe

data Value = A | B | C | D | E
  deriving (Eq)

prec' :: Value -> [[Value]] -> Int
prec' v = fromMaybe 0 . foldl' f Nothing
  where
    f (Just !len) _ = Just (len + 1)
    f Nothing vs | v `elem` vs = Just 1
                 | otherwise = Nothing
{-# INLINE prec' #-}

order :: [[Value]]
order = [
  [ B ],
  [ A, D ],
  [ C ]
  ]
{-# INLINE order #-}
  
prec :: Value -> Int
prec v = prec' v order

Compiling this with GHC 9.6, I get the following simplifier output (using -ddump-simpl and some -dsuppress-* options):

prec1 :: Int
prec1 = I# 0#

prec2 :: Int
prec2 = I# 1#

-- snip --

lvl1 :: Int
lvl1 = I# 2#

lvl2 :: Int
lvl2 = I# 3#

prec :: Value -> Int
prec
  = \ (v :: Value) ->
      case v of {
        __DEFAULT -> lvl1;
        B -> lvl2;
        C -> prec2;
        E -> prec1
      }

Why does this work?

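When you write a list literal, GHC will actually desugar it using the build function, which means that if you use a fold, your function can potentially participate in list fusion. For example, if I compile your order with -ddump-ds (which dumps the desugaring output as GHC Core), I get the following result: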

order
  = build
      (\ (@a_d1Pf)
         (c_d1Pg :: [Value] -> a_d1Pf -> a_d1Pf)
         (n_d1Ph :: a_d1Pf) ->
         c_d1Pg
           (build
              (\ (@a_d1P4)
                 (c_d1P5 :: Value -> a_d1P4 -> a_d1P4)
                 (n_d1P6 :: a_d1P4) ->
                 c_d1P5 B n_d1P6))
           (c_d1Pg
              (build
                 (\ (@a_d1P9)
                    (c_d1Pa :: Value -> a_d1P9 -> a_d1P9)
                    (n_d1Pb :: a_d1P9) ->
                    c_d1Pa A (c_d1Pa D n_d1Pb)))
              (c_d1Pg
                 (build
                    (\ (@a_d1Pc)
                       (c_d1Pd :: Value -> a_d1Pc -> a_d1Pc)
                       (n_d1Pe :: a_d1Pc) ->
                       c_d1Pd C n_d1Pe))
                 n_d1Ph)))

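So with enough inlining, and the fact that both elem and foldl' participate in list fusion as good consumers, the simplifier can optimize your function aggressively using rewrite rules.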
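The foldr/build rule behind this can be illustrated by hand. GHC's rule rewrites foldr k z (build g) into g k z, so a good consumer is plugged straight into the producer and no intermediate list is allocated. The sketch below uses build from GHC.Exts; producer, viaBuild, sumList and sumFused are illustrative names only.

```haskell
import GHC.Exts (build)

-- The "producer" form of the literal [1, 2, 3]: it abstracts over (:) and [].
producer :: (Int -> b -> b) -> b -> b
producer c n = c 1 (c 2 (c 3 n))

-- Materialising the list through build, like the desugared order shown above.
viaBuild :: [Int]
viaBuild = build producer

-- A good consumer written with foldr.
sumList :: [Int] -> Int
sumList = foldr (+) 0

-- What the foldr/build rule turns `sumList viaBuild` into:
-- (+) and 0 are substituted for (:) and [] directly.
sumFused :: Int
sumFused = producer (+) 0
```

Both sumList viaBuild and sumFused evaluate to 6; the fused form simply never builds the cons cells.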


Answer 6

Score: 1

For completeness, here is a TH solution (template-haskell-2.19.0):

{-# LANGUAGE TemplateHaskell #-}

module PrecTH where

import Language.Haskell.TH

import Data.List (nub)

prec' :: Ord a => [[a]] -> a -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs
  then 1 + length rest
  else prec' rest v

mkPrecValueDataType :: String -> [[String]] -> DecsQ
mkPrecValueDataType dtName order = pure [
     DataD [] dtName' [] Nothing
       [ NormalC (mkName c) []
       | c <- cstrs ]
       [DerivClause Nothing $ ConT <$> [''Eq, ''Ord, ''Show]]
   , SigD (mkName "prec")
      $ ArrowT `AppT` ConT dtName' `AppT` ConT ''Int
   , FunD (mkName "prec")
       [ Clause [ConP (mkName c) [] []]
                (NormalB . LitE . IntegerL $ slowPrec c)
                []
       | c <- cstrs ]
   ]
 where slowPrec = fromIntegral . prec' order
       cstrs = concat order -- apply `sort` here if you like the constructors in alphabetical order
       dtName' = mkName dtName

To be used thus:

{-# LANGUAGE TemplateHaskell #-}

module PrecValues where

import PrecTH

mkPrecValueDataType "Value" [["B"], ["A","D"], ["C"]]

and producing:

ghci> :browse
prec :: Value -> Int
type Value :: *
data Value = B | A | D | C
ghci> prec<$>[A,B,C,D]
[2,3,1,2]
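For reference, the splice should expand to declarations roughly equivalent to the hand-written module below (the real expansion can be inspected with -ddump-splices). This is a reconstruction from the generator code and the GHCi output above, not captured compiler output.

```haskell
-- Roughly what mkPrecValueDataType "Value" [["B"], ["A","D"], ["C"]] generates:
-- constructors in `concat order` order, and one clause per constructor whose
-- right-hand side is the prec' value precomputed at compile time.
data Value = B | A | D | C
  deriving (Eq, Ord, Show)

prec :: Value -> Int
prec B = 3
prec A = 2
prec D = 2
prec C = 1
```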

huangapple
  • Published on 2023-02-08 19:11:36
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75384937.html