英文:
getting C99CodePrinter to output float literals
问题
我正在尝试让sympy将一些公式转换为一种与C语言几乎相同的语言中的代码。然而,这种语言需要在浮点文字末尾加上'f',否则它会提升为双精度浮点数。即使sympy提供的文字被明确地指定为32位浮点数,它也不会这样做。我该如何做才能实现这个?
注意:
- 我在示例中直接使用了C99CodePrinter,因为我正在对其进行子类化。
- 目标语言是OpenCL,这就是为什么我们使用浮点数而不是双精度浮点数的原因(一些GPU不支持双精度浮点数)。
这是一个最简示例:
import sympy as sp
from sympy.printing.c import C99CodePrinter
import numpy as np
myPrinter = C99CodePrinter()
x = sp.symbols('x')
y = np.ones(1, dtype=np.float32)
print(myPrinter.doprint(y[0] * x))
我期望得到的是"1.0f*x",但实际上我得到的是"1.0*x"。
英文:
I'm trying to get sympy to convert some formulas into code in a language almost identical to C. However this language needs float literals to have the 'f' at the end other wise it promotes it to a double. Even when sympy is provided a literal that is explicitly typed as a 32bit float it doesn't do this. How do I make it do this?
notes:
- I'm using C99CodePrinter directly in my example because I'm subclassing it.
- target language is opencl, which is why we are using floats instead of doubles (some GPUs don't support doubles)
here is a minimal example:
import sympy as sp
from sympy.printing.c import C99CodePrinter
import numpy as np
myPrinter = C99CodePrinter()
x = sp.symbols('x')
y = np.ones(1,dtype=np.float32)
print(myPrinter.doprint(y[0]*x))
I was expecting to get "1.0f*x" but I actually get "1.0*x"
答案1
得分: 2
以下是您要求的代码的翻译部分:
所以在这里回答我的问题。 正确的最小示例应该是
```python
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.codegen.ast import real, float32
import numpy as np
myPrinter = C99CodePrinter(settings={"type_aliases":{real: float32}})
x = sp.symbols('x')
y = np.ones(1,dtype=np.float32)
print(myPrinter.doprint(y[0]*x))
似乎打印机默认假定每个实数都是float64,除非像所示覆盖。
<details>
<summary>英文:</summary>
so answering my own question here. the correct minimal example would have been
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.codegen.ast import real, float32
import numpy as np
myPrinter = C99CodePrinter(settings={"type_aliases":{real: float32}})
x = sp.symbols('x')
y = np.ones(1,dtype=np.float32)
print(myPrinter.doprint(y[0]*x))
seems like the printer defaults to assuming every real number is float64 unless its overridden as shown.
</details>
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论