0%

流畅的 Python 第 2 版(12):序列的特殊方法

这篇文章将实现一个多维向量的 Vector 类。这个类的行为与 Python 中标准的不可变扁平序列一样。这篇文章还将讨论一个概念:把协议当作正式接口。我们将说明协议和鸭子类型之间的关系,以及对自定义类型的实际影响。

Vector类:用户定义的序列类型

这里将使用组合的方式实现 Vector 类,而不使用继承。向量的分量存储在浮点数数组中,而且还将实现不可变扁平序列所需的方法。

信息检索领域经常使用 N 维向量(N是很大的数)​,因为查询的文档和文本使用向量表示,一个单词一个维度。这叫向量空间模型。在这个模型中,一个关键的相关指标是余弦相关性(表示查询的向量与表示文档的向量之间夹角的余弦)​。夹角越小,余弦值越趋近于 1,文档与查询的相关性越大。如果在实际使用中需要做向量运算,那么应该使用 NumPy 和 SciPy。

如下是 Vector 类实现的第一版:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from array import array
import reprlib
import math


class Vector:
typecode = 'd'

def __init__(self, components):
self._components = array(self.typecode, components)

def __iter__(self):
return iter(self._components)

def __repr__(self):
components = reprlib.repr(self._components)
components = components[components.find('['):-1]
return f'Vector({components})'

def __str__(self):
return str(tuple(self))

def __bytes__(self):
return (bytes([ord(self.typecode)]) +
bytes(self._components))

def __eq__(self, other):
return tuple(self) == tuple(other)

def __abs__(self):
return math.hypot(*self)

def __bool__(self):
return bool(abs(self))

@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:]).cast(typecode)
return cls(memv)
  • 使用 reprlib.repr 用于生成大型结构或递归结构的安全表示形式,它会限制输出字符串的长度,用 ... 表示截断的部分

协议和鸭子类型

在Python中创建功能完善的序列类型无须使用继承,实现符合序列协议的方法即可。那这里的协议是什么呢?在面向对象编程中,协议是非正式的接口,只在文档中定义,不在代码中定义。例如,Python 的序列协议只需要 __len____getitem__ 这两个方法。

  • 任何类(例如 Spam)​,只要使用标准的签名和语义实现了这两个方法,就能用在任何预期序列的地方
  • Spam 是不是哪个类的子类无关紧要,只要提供了所需的方法即可
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import collections

Card = collections.namedtuple('Card', ['rank', 'suit'])

class FrenchDeck:
ranks = [str(n) for n in range(2, 11)] + list('JQKA')
suits = 'spades diamonds clubs hearts'.split()

def __init__(self):
self._cards = [Card(rank, suit) for suit in self.suits
for rank in self.ranks]

def __len__(self):
return len(self._cards)

def __getitem__(self, position):
return self._cards[position]

对于上述代码,任何有经验的 Python 程序员只要看一眼就知道它是序列,即便它是 object 的子类也无妨。我们说它是序列,因为它的行为像序列,这才是重点

人们称其为鸭子类型(duck typing)。协议是非正式的,没有强制力,因此如果知道类的具体使用场景,那么通常只需要实现协议的一部分。例如,为了支持迭代,只需实现 __getitem__ 方法,没必要提供 __len__ 方法。

实现 PEP 544—Protocols: Structural subtyping (static duck typing) 之后,Python3.8 开始支持协议类(protocol class)。这里的协议与我们上面所讲的 传统协议 有关系但又不完全相同。如果需要区分:

  • 可以使用 静态协议 指代协议类规定的协议
  • 使用动态协议指代传统意义上的协议
  • 二者之间主要的区别是,静态协议的实现必须提供静态类中定义的所有方法

可切片的序列

如果能委托给对象中的序列属性(例如 self._components 数组)​,则支持序列协议特别简单:

1
2
3
4
5
6
class Vector:
def __len__(self):
return len(self._components)

def __getitem__(self, index):
return self._components[index]
1
2
3
4
5
6
7
8
9
>>> v1 = vector.Vector([3, 4, 5])
>>> len(v1)
3
>>> v1[0]
3.0
>>> v1[-1]
5.0
>>> v1[0:2]
array('d', [3.0, 4.0])

可以看到,虽然已经支持了切片,但是实现并不算完美,因为最好 Vector 实例的切片也是 Vector 的实例。想想内置序列类型:切片得到的都是各自类型的新实例,而不是其他类型。

为了把 Vector 实例的切片也变成 Vector 实例,不能简单地把切片操作委托给数组。要分析传给 __getitem__ 方法的参数,做适当的处理:

1
2
3
4
5
6
def __getitem__(self, key):
if isinstance(key, slice):
cls = type(self)
return cls(self._components[key])
index = operator.index(key)
return self._components[index]
  • 如果 key 是切片,则返回一个 Vector 实例,对应 self._components 的切片结果
  • 否则,使用 operator.index 函数将 key 转换为索引(如果转换失败,会抛出异常,以确认 key 是否是有效的索引类型)
  • 其实 key 还是元组类型(包括多个 slice 或 index),以支持多维切片,但是我们这里不支持

大量使用 isinstance 可能表明面向对象设计得不好,不过在 __getitem__ 方法中使用它处理切片是合理的。

动态存取属性

如果能通过单个字母访问前几个分量的话会比较方便。例如,用 x、y 和 z 代替 v[0]​、v[1] 和 v[2]​。在 Vector2d 中,使用 @property 装饰器把 x 和 y 标记为只读特性。但这样太麻烦,特殊方法 __getattr__ 提供了更好的方式。属性查找失败后,解释器会调用 __getattr__ 方法。简单来说,对于 my_obj.x 表达式:

  • Python 会检查 my_obj 实例有没有名为 x 的属性
  • 如果没有,就到类(my_obj.__class__)中查找
  • 如果还没有,就沿着继承图继续向上查找
  • 如果依旧找不到,则调用 my_obj 所属的类中定义的 __getattr__ 方法,传入 self 和属性名称的字符串形式(例如 ‘x’)​
1
2
3
4
5
6
7
8
9
10
11
12
__match_args__ = ('x', 'y', 'z', 't')

def __getattr__(self, name):
cls = type(self)
try:
pos = cls.__match_args__.index(name)
except ValueError:
pos = -1
if 0 <= pos < len(self._components):
return self._components[pos]
msg = f'{cls.__name__!r} object has no attribute {name!r}'
raise AttributeError(msg)
  • 设定 __match_args__,让 __getattr__ 实现的动态属性支持位置模式匹配
  • __match_args__ 一般有两个作用:在 case 子句中使用时支持位置模式,而是存储 getattr / setattr_ 的特殊逻辑实现的动态属性名称
1
2
3
4
5
6
7
8
9
10
>>> v = Vector(range(5))
>>> v.x
0.0
>>> v.y
1.0
>>> v.x = 10
>>> v.x
10
>>> v
Vector([0.0, 1.0, 2.0, 3.0, 4.0])
  • 这里的行为有些怪,为 v.x 设置新值后,v.x 返回 10,但是向量中的分量数组却没有变化
  • 仅当对象没有指定名称的属性时,Python 才会调用 __getattr__ 方法,这是一种后备机制。
  • v.x = 10 这样赋值之后,v 对象就有 x 属性了,因此使用 v.x 获取 x 属性的值时不会再调用 __getattr__ 方法,解释器会直接返回 v.x 绑定的值,即 10

为了避免这种前后矛盾的现象,需要改写 Vector 类中设置属性的逻辑,即实现 __setattr__ 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
def __setattr__(self, name, value):
cls = type(self)
if len(name) == 1:
if name in cls.__match_args__:
error = 'readonly attribute {attr_name!r}'
elif name.islower():
error = "can't set attributes 'a' to 'z' in {cls_name!r}"
else:
error = ''
if error:
msg = error.format(cls_name=cls.__name__, attr_name=name)
raise AttributeError(msg)
super().__setattr__(name, value)
  • 对名称是单个字符的属性进行检查
  • 默认情况,则在超类上调用 __setattr__ 方法,提供标准行为
  • super() 函数用于动态访问超类的方法,对 Python 这种支持多重继承的动态语言来说,必须这么做。程序员经常使用这个函数把子类方法的某些任务委托给超类中适当的方法
  • 另外如果想实现修改分量,还可以通过实现 __setitem__ 方法来支持 v[0] = 1.1 这样的赋值

我们知道,在类中声明 __slots__ 属性可以防止设置新实例属性。因此,你可能想使用这个功能,而不像这里所做的那样实现 __setattr__ 方法。但是不建议只为了避免创建实例属性而使用 __slots____slots__ 只应该用于节省内存,而且仅当内存严重不足时才应该这么做。

大多数时候,如果实现了 __getattr__ 方法,那么也要定义 __setattr__ 方法,以防对象的行为不一致

哈希和快速等值测试

我们要再次实现 __hash__ 方法,加上现有的 __eq__ 方法,这会把 Vector 实例变成可哈希的对象。我们将使用 ^(异或)运算符依次计算各个分量的哈希值,就像这样:v[0]^ v[1]^ v[2]​。这正是 functools.reduce 函数的作用。

  • reduce() 的关键思想是,把一系列值归约成单个值
  • reduce() 函数的第一个参数是一个接受两个参数的函数,第二个参数是一个可迭代对象
1
2
3
4
5
6
7
>>> import functools
>>> functools.reduce(lambda a, b: a * b, range(1, 6))
120

>>> import operator
>>> functools.reduce(operator.xor, range(6))
1

operator 模块以函数的形式提供了所有的 Python 中缀运算符,借此可以减少使用 lambda 表达式的必要。因此 Vector 类通过如下方法支持 hash 测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from array import array
import reprlib
import math
import functools
import operator


class Vector:
typecode = 'd'

def __eq__(self, other):
return tuple(self) == tuple(other)

def __hash__(self):
hashes = (hash(x) for x in self._components)
return functools.reduce(operator.xor, hashes, 0)
  • 创建一个生成器表达式,惰性计算各个分量的哈希值
  • 把 hashes 提供给 reduce 函数,使用 xor 函数计算聚合的哈希值。第三个参数(0)是初始值

归约过程则使用 xor 运算符聚合所有的哈希值。把生成器表达式替换成 map 函数,映射过程更明显:

1
2
3
def __hash__(self):
hashes = map(hash, self._components)
return functools.reduce(operator.xor, hashes)

为了提高 __eq__ 的性能,我们不再构造元组并进行元组的比较了,而是直接比较元素:

1
2
3
4
5
6
7
def __eq__(self, other):
if len(self) != len(other):
return False
for a, b in zip(self, other):
if a != b:
return False
return True
  • zip 函数生成一个由元组构成的生成器,元组中的元素来自参数传入的各个可迭代对象
  • 一旦有一个输入对象耗尽,zip 函数就会立即停止生成值,而且不发出警告。因此首先进行长度比较是必要的
  • itertools.zip_longest 函数的行为有所不同,它使用可选的 fillvalue(默认值为 None)来填充缺失的值,因此可以继续生成元组,直到最后一个可迭代对象耗尽
  • Python3.10 zip 函数增加一个可选的参数 strict,如果各个可迭代对象的长度不同,那么 zip 就应该抛出 ValueError
  • zip 函数的名称取自拉链,因为此物品把两边的链牙咬合在一起,这形象地说明了 zip(left, right) 的作用
1
2
3
>>> a = [(1, 2, 3), (4, 5, 6)]
>>> list(zip (*a))
[(1, 4), (2, 5), (3, 6)]

上面的 __eq__ 函数更简单的写法如下:

1
2
def __eq__(self, other):
return len(self) == len(other) and all(a == b for a, b in zip(self, other))

格式化

如下代码实现了以球面坐标的形式展示 Vector 向量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def angle(self, n):
r = math.hypot(*self[n:])
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and (self[-1] < 0):
return math.pi * 2 - a
else:
return a

def angles(self):
return (self.angle(n) for n in range(1, len(self)))

def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'):
fmt_spec = fmt_spec[:-1]
coords = itertools.chain([abs(self)],
self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))
  • 使用 itertools.chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标