Type hinting a function to accept NumPy arrays
Question
I'm having trouble understanding why mypy throws an error with the following example.
    import numpy as np
    from typing import Sequence

    def compute(x: Sequence[float]) -> bool:
        # some computation that does not modify x
        ...

    compute(np.linspace(0, 1, 10))

Mypy error:

    Argument 1 to "compute" has incompatible type "ndarray[Any, dtype[floating[Any]]]"; expected "Sequence[float]" [arg-type]
In particular, since typing.Sequence requires iterability, reversibility, and indexing, I'd expect a numpy array to qualify as a Sequence as well. Is this related to the fact that numpy arrays are mutable, while Sequence types are intended to be immutable? I notice that when I change Sequence to Iterable, the error goes away, but I need to be able to index x inside compute.
So, what is the best way to type-hint the compute function so that it accepts objects that support both iteration and indexing?
Answer 1
Score: 3
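np.ndarray is not a Sequence: it does not implement the .count and .index methods that collections.abc.Sequence requires (see the table of methods in the collections.abc documentation). In addition, typeshed declares Sequence as an ordinary generic class rather than a Protocol, so mypy only accepts types that explicitly inherit from it.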
Here's a relevant question.
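As a rough runtime illustration of the point above, the sketch below uses a hypothetical ArrayLike class as a stand-in for an array type (it is not numpy and is not part of the original answer). It is sized, indexable, and iterable, yet it is not treated as a Sequence until it is registered explicitly. Note that mypy does not take ABC registration into account, so this only mirrors the runtime side of the story.

```python
from collections.abc import Sequence


class ArrayLike:
    """Hypothetical stand-in for an array type: sized, indexable, iterable."""

    def __init__(self, data):
        self._data = list(data)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        return self._data[index]

    def __iter__(self):
        return iter(self._data)


arr = ArrayLike([0.0, 0.5, 1.0])

# Sequence supplies index() and count() as mixin methods to its subclasses.
sequence_has_mixins = hasattr(Sequence, "index") and hasattr(Sequence, "count")

# ArrayLike has neither method and does not inherit from Sequence.
has_index = hasattr(arr, "index")
is_sequence_before = isinstance(arr, Sequence)

# Explicit registration changes the runtime isinstance() answer only;
# static type checkers such as mypy ignore it.
Sequence.register(ArrayLike)
is_sequence_after = isinstance(arr, Sequence)
```

Here is_sequence_before is False and is_sequence_after is True, which shows that Sequence membership is decided by declaration rather than by shape.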
To get a sequence-like type that is compatible with np.ndarray, I'd suggest defining your own protocol:
    from __future__ import annotations

    from collections.abc import Collection, Reversible, Iterator
    from typing import Protocol, TypeVar, overload

    _T_co = TypeVar('_T_co', covariant=True)

    class WeakSequence(Collection[_T_co], Reversible[_T_co], Protocol[_T_co]):
        @overload
        def __getitem__(self, index: int) -> _T_co: ...
        @overload
        def __getitem__(self, index: slice) -> WeakSequence[_T_co]: ...
        def __contains__(self, value: object) -> bool: ...
        def __iter__(self) -> Iterator[_T_co]: ...
        def __reversed__(self) -> Iterator[_T_co]: ...
This is basically the same as the Sequence definition in typeshed, except that I'm using Protocol instead of Generic (which avoids adding an implementation, something stubs don't need, and makes the structural subtyping more obvious) and I omit the .index and .count methods.
Now np.ndarray should be compatible with WeakSequence, just like list or tuple, and you can use WeakSequence instead of collections.abc.Sequence in your annotations.
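A minimal usage sketch follows, assuming Python 3.9 or later (needed to subscript the collections.abc classes). The body of compute is a hypothetical placeholder, since the question elides it, and only built-in list and tuple arguments are exercised here; whether a given numpy version passes the mypy check is not verified by this snippet.

```python
from __future__ import annotations

from collections.abc import Collection, Iterator, Reversible
from typing import Protocol, TypeVar, overload

_T_co = TypeVar("_T_co", covariant=True)


class WeakSequence(Collection[_T_co], Reversible[_T_co], Protocol[_T_co]):
    """Sequence-like protocol without the .index and .count methods."""

    @overload
    def __getitem__(self, index: int) -> _T_co: ...
    @overload
    def __getitem__(self, index: slice) -> WeakSequence[_T_co]: ...
    def __contains__(self, value: object) -> bool: ...
    def __iter__(self) -> Iterator[_T_co]: ...
    def __reversed__(self) -> Iterator[_T_co]: ...


def compute(x: WeakSequence[float]) -> bool:
    # Hypothetical body: reads x by length and index, never modifies it.
    return len(x) > 0 and x[0] <= x[-1]


ascending = compute([0.0, 0.5, 1.0])  # a list satisfies the protocol
descending = compute((1.0, 0.0))      # so does a tuple
```

Because the protocol is structural, any argument type that provides these methods is accepted by the type checker without inheriting from WeakSequence.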
Answer 2
Score: 0
Try using np.ndarray as the type hint for x. I checked type(np.linspace(0, 1, 10)) and it returned np.ndarray.