Derive an optimized case-of expression from a function that takes a closed set as its input
Question
I have a closed set of values:
data Value = A | B | C | D | E ...
  deriving (Eq, Ord, Show)
And a data structure that represents their order:
order :: [[Value]]
order = [
    [ B ],
    [ A, D ],
    [ C ],
    ...
  ]
I need to convert a Value's order into an Int. I could do it like this:
prec' :: [[Value]] -> Value -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs
  then 1 + length rest
  else prec' rest v
prec :: Value -> Int
prec = prec' order
However, this prec has complexity O(n).
What I want is a very lightweight, optimized function like this one:
prec :: Value -> Int
prec = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0
  ...
But of course I don't want to write it manually, since it would risk becoming inconsistent with the information stored in order. The Haskell compiler should be able to derive that function on its own easily, since its input is a closed set.
How can I get GHC to generate a function like the last definition of prec?
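If the case expression does end up being written by hand, one way to catch it drifting out of sync with order is to compare the two over every constructor. The sketch below is a safeguard, not an answer to the question; it assumes Value also derives Bounded and Enum (not in the original), uses a concrete five-constructor Value and example order, and calls the hand-written version by the hypothetical name precCase:

```haskell
{-# LANGUAGE LambdaCase #-}

-- Concrete stand-in for the question's Value; Bounded and Enum are
-- added (assumption) so that every constructor can be enumerated.
data Value = A | B | C | D | E
  deriving (Eq, Ord, Show, Bounded, Enum)

-- Example ordering (hypothetical data), highest precedence first.
order :: [[Value]]
order = [[B], [A, D], [C]]

-- The O(n) reference implementation from the question.
prec' :: [[Value]] -> Value -> Int
prec' [] _ = 0
prec' (vs : rest) v =
  if v `elem` vs
    then 1 + length rest
    else prec' rest v

-- Hand-written O(1) version (hypothetical name).
precCase :: Value -> Int
precCase = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0

-- Constructors on which the two definitions disagree; an empty list
-- means the hand-written case expression is consistent with order.
mismatches :: [Value]
mismatches = [v | v <- [minBound .. maxBound], precCase v /= prec' order v]
```

Running such a check in a test suite detects drift, but it does not remove the duplication the question is trying to avoid.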
Answer 1
Score: 8
Solution 1: Use Template Haskell to generate the code you want.
Solution 2 (expanded below): (Ab)use the simplifier.
The main obstacle to simplification is that GHC will not inline recursive functions. One workaround is to do the recursion through type classes.
-- Intuitively unroll :: Nat -> (a -> a) -> (a -> a)
-- but the nat is now a type-level parameter.
class Unroll (n :: Nat) where
  unroll :: (a -> a) -> (a -> a)
instance Unroll 0 where
  unroll = id
instance {-# OVERLAPPABLE #-} Unroll (n-1) => Unroll n where
  unroll f = f . unroll @(n-1) f
This lets you define the following fixpoint operator that unfolds the first n iterations:
unrollfix :: forall n a. Unroll n => (a -> a) -> a
unrollfix f = unroll @n f (fix f)
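To see that unrollfix only changes how the recursion is laid out, not what it computes, here is a minimal self-contained sketch that reuses the Unroll class and unrollfix from above on a hypothetical factorial. The explicit extension list is an assumption intended to make it compile without relying on GHC2021 defaults:

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, FlexibleContexts, FlexibleInstances,
             KindSignatures, ScopedTypeVariables, TypeApplications,
             UndecidableInstances #-}

import Data.Function (fix)
import GHC.TypeNats

-- Same class and instances as above.
class Unroll (n :: Nat) where
  unroll :: (a -> a) -> (a -> a)

instance Unroll 0 where
  unroll = id

instance {-# OVERLAPPABLE #-} Unroll (n - 1) => Unroll n where
  unroll f = f . unroll @(n - 1) f

unrollfix :: forall n a. Unroll n => (a -> a) -> a
unrollfix f = unroll @n f (fix f)

-- Hypothetical example: factorial as the fixpoint of a non-recursive step.
factStep :: (Integer -> Integer) -> Integer -> Integer
factStep self n = if n == 0 then 1 else n * self (n - 1)

-- Ordinary knot-tying with fix.
factFix :: Integer -> Integer
factFix = fix factStep

-- The first 3 applications of factStep are unrolled via the type class;
-- deeper calls fall back to fix, so results match factFix.
factUnrolled :: Integer -> Integer
factUnrolled = unrollfix @3 factStep
```

Because unroll @3 factStep expands to factStep . factStep . factStep, the simplifier sees three non-recursive copies it is free to inline; only the fix at the bottom stays recursive.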
You then need to write all recursive functions using fix, and replace fix with unrollfix. You have to sprinkle some INLINE pragmas around too.
elem with fix:
elem :: forall a. Eq a => a -> [a] -> Bool
elem = fix go
  where
    go elem_ x [] = False
    go elem_ x (y : ys) = x == y || elem_ x ys
elem with unrollfix:
{-# INLINE uelem #-}
uelem :: forall n a. (Unroll n, Eq a) => a -> [a] -> Bool
uelem = unrollfix @n go
  where
    go elem_ x [] = False
    go elem_ x (y : ys) = x == y || elem_ x ys
The same goes for length (omitted) and prec'.
prec' with fix:
prec' :: forall a. Eq a => [[a]] -> a -> Int
prec' = fix go
  where
    go prec_ [] v = 0
    go prec_ (vs : rest) v = if elem v vs
      then 1 + length rest
      else prec_ rest v
prec' with unrollfix:
prec' :: forall n a. (Unroll n, Eq a) => [[a]] -> a -> Int
prec' = unrollfix @n go
  where
    go prec_ [] v = 0
    go prec_ (vs : rest) v = if uelem @n v vs
      then 1 + ulength @n rest
      else prec_ rest v
    {-# INLINE go #-}
Finally, set the n parameter to a value high enough to enable simplification.
prec :: Value -> Int
prec v = prec' @5 order v
Full code:
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, FlexibleContexts, FlexibleInstances, KindSignatures, MultiParamTypeClasses, ScopedTypeVariables, TypeApplications, UndecidableInstances #-}
{-# OPTIONS_GHC -ddump-simpl #-}
module A (Value(..), prec) where
import GHC.TypeNats
import Data.Function (fix)
import GHC.Exts
class Unroll (n :: Nat) where
  unroll :: (a -> a) -> (a -> a)
instance Unroll 0 where
  unroll = id
instance {-# OVERLAPPABLE #-} Unroll (n-1) => Unroll n where
  unroll f = f . unroll @(n-1) f
unrollfix :: forall n a. Unroll n => (a -> a) -> a
unrollfix f = unroll @n f (fix f)
data Value = A | B | C | D | E
  deriving Eq
order :: [[Value]]
order = [[A], [B, C], [D], [E]]
{-# INLINE uelem #-}
uelem :: forall n a. (Unroll n, Eq a) => a -> [a] -> Bool
uelem = unrollfix @n go
  where
    go elem_ x [] = False
    go elem_ x (y : ys) = x == y || elem_ x ys
    {-# INLINE go #-}
{-# INLINE ulength #-}
ulength :: forall n a. Unroll n => [a] -> Int
ulength = unrollfix @n go
  where
    go length_ [] = 0
    go length_ (_ : xs) = 1 + length_ xs
    {-# INLINE go #-}
prec' :: forall n a. (Unroll n, Eq a) => [[a]] -> a -> Int
prec' = unrollfix @n go
  where
    {-# INLINE go #-}
    go prec_ [] v = 0
    go prec_ (vs : rest) v = if uelem @n v vs
      then 1 + ulength @n rest
      else prec_ rest v
prec :: Value -> Int
prec v = prec' @5 order v
Generated Core (using the -ddump-simpl option; look at the unfolding, not the main definition):
\ (v_aQC [Occ=Once1!] :: Value) ->
  case v_aQC of {
    __DEFAULT -> GHC.Types.I# 3#;
    A -> GHC.Types.I# 4#;
    D -> GHC.Types.I# 2#;
    E -> GHC.Types.I# 1#
  }
Answer 2
Score: 4
I would just wrap a general-purpose memoization tool around the function, such as MemoTrie, memoize, or fastmemo.
{-# LANGUAGE DeriveGeneric, DeriveAnyClass #-}
import Data.Function.FastMemo
import GHC.Generics (Generic)
data Value = A | B | C | D | E ...
  deriving (Eq, Ord, Show, Generic, Memoizable)
prec :: Value -> Int
prec = memoize $ prec' order
This may not be as optimized as a direct TH solution, but the Generic-derived Memoizable instance should result in something reasonably similar. I'm not sure which of these packages does it best.
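If pulling in a memoization package is undesirable, roughly the same effect for a small enumeration can be had with a lazily built array from the array package that ships with GHC. This is a minimal sketch, assuming Value derives Bounded and Enum; memoEnum is a hypothetical helper, not an API of any of the packages named above:

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Array (Array, listArray, (!))

data Value = A | B | C | D | E
  deriving (Eq, Show, Bounded, Enum)

-- Example ordering (hypothetical data), highest precedence first.
order :: [[Value]]
order = [[B], [A, D], [C]]

-- The O(n) function from the question.
prec' :: [[Value]] -> Value -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs then 1 + length rest else prec' rest v

-- Hypothetical helper: tabulate f over every constructor.
-- The array is lazy, so each entry is computed on first use and then shared.
memoEnum :: forall a b. (Bounded a, Enum a) => (a -> b) -> a -> b
memoEnum f = \x -> table ! fromEnum x
  where
    lo = fromEnum (minBound :: a)
    hi = fromEnum (maxBound :: a)
    table :: Array Int b
    table = listArray (lo, hi) (map f [minBound .. maxBound])

prec :: Value -> Int
prec = memoEnum (prec' order)
```

Because prec is a top-level binding, the table is built at most once per program run; the lookup is an index into an array rather than a case expression, so it is a cheaper approximation of the requested function, not the same Core.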
Answer 3
Score: 4
Define prec first, then generate order using prec.
{-# LANGUAGE LambdaCase #-}
import Data.Function (on)
import Data.List (groupBy, sortBy)
prec :: Value -> Int
prec = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0
order :: [[Value]]
order = go [A, B, C, D, E]
  where eqPrec = (==) `on` prec
        ordPrec = compare `on` prec
        go = reverse . groupBy eqPrec . sortBy ordPrec
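One remaining source of drift in the version above is the hand-maintained list [A, B, C, D, E]. If Value also derives Bounded and Enum (an assumption), the constructor list can come from [minBound .. maxBound] instead, and -Wincomplete-patterns flags any constructor that prec forgets. A minimal self-contained sketch:

```haskell
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wincomplete-patterns #-}

import Data.Function (on)
import Data.List (groupBy, sortBy)

data Value = A | B | C | D | E
  deriving (Eq, Ord, Show, Bounded, Enum)

-- Single source of truth: a total case expression.
prec :: Value -> Int
prec = \case
  A -> 2
  B -> 3
  C -> 1
  D -> 2
  E -> 0

-- Derived ordering, highest precedence first. sortBy is stable, so
-- constructors with equal precedence keep their declaration order.
order :: [[Value]]
order =
  reverse . groupBy ((==) `on` prec) . sortBy (compare `on` prec) $
    [minBound .. maxBound]
```

With these definitions, order evaluates to [[B], [A, D], [C], [E]], matching the question's layout plus the lowest-precedence constructor E.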
Answer 4
Score: 3
Perhaps one simple solution would be to do by hand what you want the compiler to do automatically once it has the case statement you describe: compute a jump table.
import Data.Array (Array, listArray)
import Data.Array.Base (unsafeAt)
-- deriving Enum makes a compiler-written case statement like what you want
data Foo = A | B | C | D | E deriving (Bounded, Enum)
-- orderSlow :: Foo -> Int is your existing O(n) function
orderArray :: Array Int Int
orderArray = listArray
  (0, fromEnum (maxBound :: Foo))  -- bounds are inclusive: one slot per constructor
  (orderSlow <$> [minBound..maxBound])
prec :: Foo -> Int
prec = unsafeAt orderArray . fromEnum
This will have to run orderSlow once for each possible value, but subsequent accesses will be fast O(1) lookups.
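For reference, here is a runnable version of the jump table. It assumes orderSlow is the question's prec' order (the name orderSlow comes from this answer; the concrete order is example data), and it swaps unsafeAt for the bounds-checked (!), which costs one comparison per lookup but fails loudly instead of reading past the array if the bounds are ever wrong:

```haskell
import Data.Array (Array, listArray, (!))

data Foo = A | B | C | D | E
  deriving (Eq, Show, Bounded, Enum)

-- Example ordering (hypothetical data), highest precedence first.
order :: [[Foo]]
order = [[B], [A, D], [C]]

-- The slow O(n) function from the question, used here as orderSlow.
orderSlow :: Foo -> Int
orderSlow = go order
  where
    go [] _ = 0
    go (vs : rest) v = if v `elem` vs then 1 + length rest else go rest v

-- One entry per constructor: indices 0 .. fromEnum maxBound, inclusive.
orderArray :: Array Int Int
orderArray =
  listArray (0, fromEnum (maxBound :: Foo)) (orderSlow <$> [minBound .. maxBound])

prec :: Foo -> Int
prec = (orderArray !) . fromEnum
```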
Answer 5
Score: 2
Rewrite your prec' function using a fold, in this case foldl', and add INLINE pragmas to prec' and order.
{-# LANGUAGE BangPatterns #-}
import Data.Foldable
import Data.Maybe
data Value = A | B | C | D | E
  deriving (Eq)
prec' :: Value -> [[Value]] -> Int
prec' v = fromMaybe 0 . foldl' f Nothing
  where
    f (Just !len) _ = Just (len + 1)
    f Nothing vs | v `elem` vs = Just 1
                 | otherwise = Nothing
{-# INLINE prec' #-}
order :: [[Value]]
order = [
    [ B ],
    [ A, D ],
    [ C ]
  ]
{-# INLINE order #-}
prec :: Value -> Int
prec v = prec' v order
Compiling this with GHC 9.6, I get the following simplifier output (using -ddump-simpl and some -dsuppress-* options):
prec1 :: Int
prec1 = I# 0#
prec2 :: Int
prec2 = I# 1#
-- snip --
lvl1 :: Int
lvl1 = I# 2#
lvl2 :: Int
lvl2 = I# 3#
prec :: Value -> Int
prec
  = \ (v :: Value) ->
      case v of {
        __DEFAULT -> lvl1;
        B -> lvl2;
        C -> prec2;
        E -> prec1
      }
Why does this work?
When you write a list literal, GHC will actually desugar it using the build function, which means that if you consume it with a fold, your function can potentially participate in list fusion. For example, if I compile your order with -ddump-ds (which dumps the desugaring output as GHC Core), I get the following result:
order
  = build
      (\ (@a_d1Pf)
         (c_d1Pg :: [Value] -> a_d1Pf -> a_d1Pf)
         (n_d1Ph :: a_d1Pf) ->
         c_d1Pg
           (build
              (\ (@a_d1P4)
                 (c_d1P5 :: Value -> a_d1P4 -> a_d1P4)
                 (n_d1P6 :: a_d1P4) ->
                 c_d1P5 B n_d1P6))
           (c_d1Pg
              (build
                 (\ (@a_d1P9)
                    (c_d1Pa :: Value -> a_d1P9 -> a_d1P9)
                    (n_d1Pb :: a_d1P9) ->
                    c_d1Pa A (c_d1Pa D n_d1Pb)))
              (c_d1Pg
                 (build
                    (\ (@a_d1Pc)
                       (c_d1Pd :: Value -> a_d1Pc -> a_d1Pc)
                       (n_d1Pe :: a_d1Pc) ->
                       c_d1Pd C n_d1Pe))
                 n_d1Ph)))
So with enough inlining, and given that both elem and foldl' participate in list fusion as good consumers, the simplifier can optimize your function aggressively using rewrite rules.
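The central rewrite rule here is foldr/build fusion: foldr k z (build g) = g k z, so a list produced with build and consumed with foldr never has to be allocated. The sketch below applies that equation by hand using build from GHC.Exts; the names g, viaList and fused are illustrative only, and this does not show GHC's rule engine itself:

```haskell
import GHC.Exts (build)

-- A list "producer" abstracted over cons (c) and nil (n),
-- the same shape as each build argument in the desugared order above.
g :: (Int -> b -> b) -> b -> b
g c n = c 1 (c 2 (c 3 n))

-- Materialise the list with build, then consume it with foldr.
viaList :: Int
viaList = foldr (+) 0 (build g)

-- What the foldr/build rule rewrites that into: no intermediate list.
fused :: Int
fused = g (+) 0
```

Both evaluate to 6. Since foldl' and elem on lists are defined in terms of foldr in base, once order and prec' are inlined the same rewrite can eliminate the lists and leave only the case analysis on v.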
Answer 6
Score: 1
For completeness, here is a TH solution (template-haskell-2.19.0):
{-# LANGUAGE TemplateHaskell #-}
module PrecTH where
import Language.Haskell.TH
import Data.List (nub)
prec' :: Ord a => [[a]] -> a -> Int
prec' [] _ = 0
prec' (vs : rest) v = if v `elem` vs
  then 1 + length rest
  else prec' rest v
mkPrecValueDataType :: String -> [[String]] -> DecsQ
mkPrecValueDataType dtName order = pure [
    DataD [] dtName' [] Nothing
      [ NormalC (mkName c) []
      | c <- cstrs ]
      [DerivClause Nothing $ ConT <$> [''Eq, ''Ord, ''Show ]]
  , SigD (mkName "prec")
      $ ArrowT `AppT` ConT dtName' `AppT` ConT ''Int
  , FunD (mkName "prec")
      [ Clause [ConP (mkName c) [] []]
          (NormalB . LitE . IntegerL $ slowPrec c)
          []
      | c <- cstrs ]
  ]
  where slowPrec = fromIntegral . prec' order
        cstrs = concat order -- apply `sort` here if you like the constructors in alphabetical order
        dtName' = mkName dtName
To be used like this:
{-# LANGUAGE TemplateHaskell #-}
module PrecValues where
import PrecTH
mkPrecValueDataType "Value" [["B"], ["A","D"], ["C"]]
which produces:
ghci> :browse
prec :: Value -> Int
type Value :: *
data Value = B | A | D | C
ghci> prec<$>[A,B,C,D]
[2,3,1,2]