英文:
Error in a custom sympy function when trying to print the squared function with latex
问题
我想编写一个带有特殊LaTeX输出的自定义SymPy函数。一个简化的函数可能看起来像下面这样:
```python
from sympy import Function, Symbol, latex
class TestClass(Function):
def _latex(self, printer):
m = self.args[0]
_m = printer._print(m)
return _m + '(x)';
S = Symbol('S')
latex(TestClass(S)**2)
然而,这给了我一个错误
TypeError: TestClass._latex() 收到了一个意外的关键字参数 'exp'
请问有人可以帮我理解这里出了什么问题吗?
<details>
<summary>英文:</summary>
I would like to write a custom SymPy Function with special latex output. A simplified function might look like the following:
```python
from sympy import Function, Symbol, latex
class TestClass(Function):
def _latex(self, printer):
m = self.args[0]
_m = printer._print(m)
return _m + '(x)'
S = Symbol('S')
latex(TestClass(S)**2)
However, this gives me the error
> TypeError: TestClass._latex() got an unexpected keyword argument 'exp'
Can someone please help me to understand what is going wrong here?
答案1
得分: 1
TestClass
是 Function
的子类。
TestClass(S)**2
是 Pow
的一个实例。让我们看一下 LatexPrinter
类的源代码,具体是 _print_Pow
方法:
if expr.base.is_Function:
return self._print(expr.base, exp=self._print(expr.exp))
在底层,self._print
检查 expr.base
是否实现了 _latex
方法。如果是的话(就像在你的测试案例中一样),它会调用它并传递关键字参数 exp
。
所以,你需要调整你的代码:
class TestClass(Function):
def _latex(self, printer, exp=None):
m = self.args[0]
_m = printer.doprint(m)
base = _m + '(x)'
if exp is None:
return base
return base + "^{%s}" % exp
S = Symbol('S')
expr = TestClass(S)
latex(expr**3)
# 输出: 'S(x)^{3}'
英文:
TestClass
is a subclass of Function
.
TestClass(S)**2
is an instance of Pow
. Let's look at the source code of the LatexPrinter
class, specifically at the _print_Pow
method:
if expr.base.is_Function:
return self._print(expr.base, exp=self._print(expr.exp))
Under the hood, self._print
checks if expr.base
implements the _latex
method. If it does (like in your test case), it calls it and pass along the keyword arguments, exp
.
So, you need to adjust your code:
class TestClass(Function):
def _latex(self, printer, exp=None):
m = self.args[0]
_m = printer.doprint(m)
base = _m + '(x)'
if exp is None:
return base
return base + "^{%s}" % exp
S = Symbol('S')
expr = TestClass(S)
latex(expr**3)
# out: 'S(x)^{3}'
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论