在Julia中将函数作为参数传递(最佳实践 / 代码设计)

huangapple go评论57阅读模式
英文:

A function as an argument in Julia (best practice / code design)

问题

以下是翻译好的内容:

Julia编程语言中以下问题的最佳/最优雅的代码设计是什么:

模块 mod1mod2mod3 等实现了一个复杂的函数 fun(a,b,c),该函数计算并返回一个矩阵 M

我的程序中的一个函数 foo 需要重复使用函数 fun 来计算矩阵 M,并且输入参数 a, b, c 可能会变化。然而,我的程序的用户可能会指定使用哪个模块来执行 fun。请注意,用户不是程序员,只想输入类似 julia program.jl 的命令,并在某个输入文件中使用类似 "MODULE = mod2" 的字符串来指定所使用的模块。

最不优雅和最原始的解决方案可能是使用 if...elseif... 语句,如下所示:

function foo(...)
[...]
    if SpecifiedModule == "mod1"
        M = mod1.fun(a, b, c)
    elseif SpecifiedModule == "mod2"
        M = mod2.fun(a, b, c)
    elseif SpecifiedModule == "mod3"
        M = mod3.fun(a, b, c)
    end
    # 处理 M

另一种方法是使用函数指针,或者甚至使用类似于C++中的 "functors"。

在Julia中,对于这种情况,最佳做法是什么?有没有办法摆脱这种丑陋的 if...elseif... 代码?

PS:与在 Optim.jl 模块中找到函数的最优值的情况相似,例如(https://github.com/JuliaNLSolvers/Optim.jl/)。如果我没有弄错的话,用户可以指定一个函数 f 作为参数传递给 optimize(f, ...) 函数。

英文:

What is the best/most elegant code design in the Julia programming language for the following problem:

The modules mod1, mod2, mod3, ... implement a complicated function fun(a,b,c), which calculates and returns a matrix M.

A function foo of my program needs to calculate the matrix M repeatably via the function fun with varying input parameters a,b,c. However, the user of my program can specify which module is used for fun. Note, that the user is no programmer and only wants to type something like julia program.jl and specifies the used module with a string like "MODULE = mod2" in some input file.

The most naive and inelegant solution would probably be via if...elseif...

function foo(...
[...]
    if SpecifiedModule == "mod1"
        M = mod1.fun(a,b,c)
    elseif SpecifiedModule == "mod2"
        M = mod2.fun(a,b,c)
    SpecifiedModule == "mod3"
        M = mod3.fun(a,b,c)
    end
    # process M

An alternative way would be to use function-pointers. Or even functors as known from C++?

What is best practice in Julia for this case? Is there a way to get rid of this ugly if..elseif... code?

PS: there are similarities to the case of finding an optimized value of a function, like in the module Optim.jl (https://github.com/JuliaNLSolvers/Optim.jl/). If I am not wrong, a user specified function f is passed as an argument to the optimize(f, ...) function.

答案1

得分: 3

相同的返回类型

如果返回类型保证相同,您可以使用Match.jl并在Symbol或枚举上进行模式匹配。

import Pkg
Pkg.add("Match")
using Match

function fun(args, SpecifiedModule::Symbol)
    M = @match SpecifiedModule begin
        :mod1 => Mod1.fun(args...)
        :mod2 => Mod2.fun(args...)
        :mod3 => Mod3.fun(args...)
        _ => Mod1.fun(args)  # 默认行为,错误处理等
    end
end

M1 = fun(args, :mod1)  # 使用Mod1调用
M2 = fun(args, :mod2)  # 使用Mod2调用
# ... 等等

如果fun的返回类型保证相同,那么这应该可以正常工作,否则将会存在类型不稳定性。

通用的返回类型,但编译时模块已知

如果模块名称在编译时保证已知,您可以使用Val和符号,将Julia的多分派/函数重载用作模式匹配机制。

get_module(::Val{:mod1}) = Mod1
get_module(::Val{:mod2}) = Mod2
get_module(::Val{:mod3}) = Mod3
# ... 等等

function fun(args, SpecifiedModule)
    get_module(SpecifiedModule).fun(args)
end

fun(args, Val(:mod1))  # 调用Mod1.fun
fun(args, Val(:mod2))  # 调用Mod2.fun

Val的参数只能是编译时已知的isbitstype值,所以不能将例如:mod1分配给一个变量然后传递它。关于Val的一些说明(充分披露,我是作者)。

foo作为高阶函数

就像在Optim.jl的示例中一样,您可以让用户传递一个任意的函数f,例如foo(a, b, c, Mod1.fun)。当输入函数需要相对通用,比如您想要优化、区分等的某个函数时,这是一个自然的选择。我不确定您的特定用例,但如果您想要限制用户的选择为一组特定的函数,那么最好避免这种方法。

英文:

Identical return types

If the return types are guaranteed to be identical, you could use Match.jl and pattern match over Symbols or enums.

import Pkg
Pkg.add("Match")
using Match

function fun(args, SpecifiedModule::Symbol)
    M = @match SpecifiedModule begin
        :mod1 => Mod1.fun(args...)
        :mod2 => Mod2.fun(args...)
        :mod3 => Mod3.fun(args...)
        _ => Mod1.fun(args)  # default behaviour, error handling etc
    end
end

M1 = fun(args, :mod1)  # call with Mod1
M2 = fun(args, :mod2)  # call with Mod2
# ... and so on

This should work fine if the return types of fun are guaranteed to be the same, else there will be type instabilities.

General return types, but module known at compile time

If the module name is guaranteed to be known at compile, you can use Val with symbols to use Julia's multiple dispatch/function overloading as a pattern matching mechanism.

get_module(::Val{:mod1}) = Mod1
get_module(::Val{:mod2}) = Mod2
get_module(::Val{:mod3}) = Mod3
# ... and so on


function fun(args, SpecifiedModule)
    get_module(SpecifiedModule).fun(args)
end

fun(args, Val(:mod1))  # calls Mod1.fun
fun(args, Val(:mod2))  # calls Mod2.fun

Arguments to Val can only be compile-time known isbitstype values, so you cannot assign e.g. :mod1 to a variable and pass it. Some exposition on Val (full disclosure, I am the author).

foo as a higher order function

Just as in the example of Optim.jl, you could let users pass an arbitrary function f like foo(a, b, c, Mod1.fun). This is a natural choice when the input function needs to be somewhat general, such as some function that you want to optimize, differentiate etc. I am not sure about your particular use case, but if you want to restrict the users' choices to a specific set of functions, then it would be better to avoid this.

huangapple
  • 本文由 发表于 2023年6月19日 20:38:50
  • 转载请务必保留本文链接:https://go.coder-hub.com/76506711.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定