0%

流畅的 Python(6):序列的修改、散列和切片

这篇文章将自定义一个序列类型 Vector,进一步探讨 Python 序列的修改、散列和切片等操作所涉及的背后原理。

Vector 类:用户定义的序列类型

这里实现一个 Vector 类,这个类的行为与 Python 中标准的不可变扁平序列基本一致:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/env python3

# Copyright (C) fuchencong.com

from array import array
import reprlib
import math


class Vector:
typecode = 'd'

def __init__(self, components):
self._components = array(self.typecode, components)

def __iter__(self):
return iter(self._components)

def __repr__(self):
components = reprlib.repr(self._components)
components = components[components.find('['):-1]
return 'Vecotr({})'.format(components)

def __str__(self):
return str(tuple(self))

def __bytes__(self):
return bytes(ord(self.typecode)) + bytes(self._components)

def __eq__(self, other):
return tuple(self) == tuple(other)

def __abs__(self):
return math.sqrt(sum(x * x for x in self))

def __bool__(self):
return bool(abs(self))

@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:].cast(typecode))
return cls(memv)
1
2
3
4
5
6
7
>>> import vector_v1
>>> v1 = vector_v1.Vector((2, 3, 4))
>>> v2 = vector_v1.Vector(range(10))
>>> v1
Vecotr([2.0, 3.0, 4.0])
>>> v2
Vecotr([0.0, 1.0, 2.0, 3.0, 4.0, ...])
  • 这里 _components 是一个受保护的实例属性
  • 序列类型的构造方法最好接收可迭代的对象作为参数,因为所有内置的序列类型都是这样做的
  • 字符串的表示形式是用于调试,使用 reprlib 模块可以生成长度有限的表示形式。reprlib.repr 用于生成大型结构或递归结构的安全形式,它会限制输出字符串的长度。

协议和鸭子类型

在 Python 中,创建功能完善的序列类型无需使用继承,只需实现符合序列协议的方法。在面对对象编程中,协议是非正式的接口,只在文档中定义,在代码中不定义。例如 Python 序列协议只要实现了 __len____getitem__ 两个方法,就能用于任何期待序列的地方。我们说它是序列,因为它的行为像序列,人们将其称为鸭子类型。

协议是非正式的,没有强制力,如果你知道你的类的具体使用场景,通常只需要实现一个协议的部分。例如如果只是为了支持迭代,只需要实现 __getitem__ 方法,没有必要提供 __len__ 方法。

Vector 类第 2 版:可切片的序列

只需要添加如下方法,Vector 类就支持了序列协议:

1
2
3
4
5
def __len__(self):
return len(self._components)

def __getitem__(self, position):
return self._components[position]
1
2
3
4
5
6
7
8
9
10
>>> import vector_v2
>>> v3 = vector_v2.Vector(range(5))
>>> v3
Vecotr([0.0, 1.0, 2.0, 3.0, 4.0])
>>> v3[1]
1.0
>>> v3[2:3]
array('d', [2.0])
>>> len(v3)
5

现在 Vector 已经支持切片了,但是切片返回的实例是 array 对象,而不是 Vector 实例。内置的序列类型,其切片结果也都是各自类型的新实例,而不是其他类型。为了让 Vector 实例的切片仍然是 Vector 的实例,就不能简单地委托给数组切片。需要分析传递给 getitem 方法的参数,做适当的处理,下面的简单例子可以说明这一点:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> class MySeq:
... def __getitem__(self, index):
... return index
...
>>> s = MySeq()
>>> s[1]
1
>>> s[1:4]
slice(1, 4, None)
>>> s[1:4:2]
slice(1, 4, 2)
>>> s[1:4:2, 9]
(slice(1, 4, 2), 9)
>>> s[1:4:2, 7:9]
(slice(1, 4, 2), slice(7, 9, None))

可以看到,如果 [] 中有逗号,那么 __getitem__ 收到的是元祖。slice 提供一个 indices(len) 方法,对于给定长度为 len 的序列,计算 S 表示的扩展切片的起始(start)和结尾索引(stop)、以及步幅(stride)。超出边界的索引会被截掉。indices 方法开放了内置序列实现的棘手逻辑,用于优雅地处理缺失索引和负数索引,以及长度超过目标序列的切片。该方法会整顿元祖,把 start、up、stride 都变成非负数,而且都落在指定长度序列的边界内。例如:

1
2
3
4
5
SyntaxError: invalid character ')' (U+FF09)
>>> slice(None, 10, 2).indices(5)
(0, 5, 2)
>>> slice(-3, None, None).indices(5)
(2, 5, 1)

在 Vector 类中无需使用 slice.indices() 方法,因为我们直接将切片参数委托给 _components 数组处理。当你没有底层序列类型作为依靠时,那么使用这个方法能够节省大量时间。接下来重新实现能够处理切片的 __getitem__ 方法:

1
2
3
4
5
6
7
8
9
def __getitem__(self, position):
cls = type(self)
if isinstance(position, slice):
return cls(self._components[position])
elif isinstance(position, numbers.Integral):
return self._components[position]
else:
msg = '{cls.__name__} indices must be intergers'
raise TypeError(msg.format(cls=cls))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> import vector_v2
>>> v1 = vector_v2.Vector(range(10))
>>> v1[1]
1.0
>>> v1[1:3]
Vecotr([1.0, 2.0])
>>> v1[1:5:2]
Vecotr([1.0, 3.0])
>>> v1[1:2]
Vecotr([1.0])
>>> v1[1,2]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v2.py", line 57, in __getitem__
raise TypeError(msg.format(cls=cls))
TypeError: Vector indices must be intergers

动态存取属性

现在的 Vector 类可有有大量分量,但是无法通过名称访问向量的分量了。如果能通过单个字母来访问前几个分量的话会比较方便。例如用 x、y、z 代替 v[0]、v[1]、v[2]。通过特殊方法 __getattr__ 可以实现这一点。

属性查找失败后,解释器会调用 __getattr__ 方法。属性查找机制可以简单归纳如下,以 my_obj.x 表达式为例:

  • 解释器首先查找 my_obj 实例有没有 x 属性
  • 如果没有,到类 my_obj.class 中查找
  • 如果还没有,顺着继承树继续查找
  • 如果依旧查找不到,调用 my_obj 所属类中定义的 __getattr__ 方法,传入 self 和属性名称的字符串形式

如下实现了 __getattr__ 方法:

1
2
3
4
5
6
7
8
9
10
shortcut_names = 'xyzt'

def __getattr__(self, name):
cls = type(self)
if len(name) == 1:
pos = cls.shortcut_names.find(name)
if 0 <= pos < len(cls.shortcut_names):
return self._components[pos]
msg = '{.__name__!r} object has no attribute {!r}'
raise AttributeError(msg.format(cls, name))

简单的测试如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> import vector_v3
>>> v1 = vector_v3.Vector(range(10))
>>> v1.x
0.0
>>> v1.y
1.0
>>> v1.z
2.0
>>> v1.t
3.0
>>> v1.w
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v3.py", line 68, in __getattr__
raise AttributeError(msg.format(cls, name))
AttributeError: 'Vector' object has no attribute 'w'

在看下面这个使用例子:

1
2
3
4
5
6
7
>>> v1.x
0.0
>>> v1.x = 10
>>> v1
Vecotr([0.0, 1.0, 2.0, 3.0, 4.0, ...])
>>> v1.x
10

这是因为当使用 v1.x = 10 赋值语句后,v 对象就有了 x 属性。因此 v1 对象就有了 x 属性了,因此使用 v1.x 就不会再调用 __getattr__ 方法了。为了避免这种前后矛盾的现象,需要改写 Vector 类中设置属性的逻辑,即 __setattr__ 方法,大多数时候如果定义了 __getattr__ 方法,那么也要定义 __setattr__ 方法,这样才能避免行为不一致:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def __setattr__(self, name, value):
cls = type(self)
if len(name) == 1:
if name in cls.shortcut_names:
error = 'readonly attribute {attr_name!r}'
elif name.islower():
error = "can't set attribute a to z in {cls_name!r}"
else:
error = ''

if error:
msg = error.format(cls_name=cls.__name__, attr_name=name)
raise AttributeError(msg)
super().__setattr__(name, value)

该方法对于非单个小写英文字母的属性,仍然提供标准些行为,这是通过在超类中调用 __setattr__ 实现的。程序测试结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
Vecotr([1.0])
>>> import vector_v3
>>> v1 = vector_v3.Vector(range(10))
>>> v1.x = 10
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v3.py", line 82, in __setattr__
raise AttributeError(msg)
AttributeError: readonly attribute 'x'
>>> v1.a = 10
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v3.py", line 82, in __setattr__
raise AttributeError(msg)
AttributeError: can't set attribute a to z in 'Vector'
>>> v1.value = 10
>>> v1.value
10

之前介绍过,在类中声明 __slots__ 属性可以防止设置新的实例属性,但是当时也说过,__slots__ 属性只应该用于节省内存,不建议只为了避免创建实例属性而使用 __slots__

如果想允许修改分量,可以使用 __setitem__ 方法,以支持 v[0]= XXX 的赋值方式,或者实现 __setattr__ 方法,以支持 v.x = XXX 的赋值方式。但是为了让 Vector 成为可散列类型,我们没有实现这些方法,以保持 Vector 是不可变的。

散列和快速等值测试

接下来实现 Vector 类型的 __hash__ 方法,加上现有的 __eq__ 方法,这可以把 Vector 实例变成可散列的对象。__hash__ 方法使用 ^ 运算符计算各个分量的散列值,要实现该运算,可以有多重写法,例如:

1
2
3
4
5
6
7
8
9
10
11
12
>>> n = 0
>>> for i in range(1, 6):
... n ^= i
...
>>> n
1
>>> import functools
>>> functools.reduce(lambda a, b: a^b, range(6))
1
>>> import operator
>>> functools.reduce(operator.xor, range(6))
1

最后一种方法最为简单,operator 模块以函数的形式,提供了 python 的全部中缀运算符,从而减少了 lambda 表达式。如下给出了 __hash__ 方法的实现:

1
2
3
def __hash__(self):
hashes = (hash(x) for x in self._components)
return functools.reduce(operator.xor, hashes, 0)

这里再 reduce 中提供了第三个参数 initializer。如果序列为空,initializer 是返回的结果,否则在规约中使用它作为第一个参数,因此应该使用恒等式。对于 +、|、* 来说,initializer 应该是 0,而对于 *& 来说,应该是 1。

另外,为了提高 Vector 类中 __eq__ 方法的效率(原方法需要复制操作对象,以创建元祖进行对比,当 Vector 包含上万个分量时效率不高),重写 __eq__ 方法:

1
2
3
4
5
6
7
8
def __eq__(self, other):
if len(self) != len(other):
return False

for a, b in zip(self, other):
if a != b:
return False
return True

这里 zip 函数接收可迭代对象作为参数,返回一个由元祖构成的生成器,元祖中的元素就是来自于参数中可迭代对象。zip 函数能够轻松地并行迭代两个或多个可迭代对象,它返回的元祖可以拆包成变量,分别对应各个并行输入对象中的一个元素。这里比较长度是必要的,因为 zip 函数处理过程中,一旦有一个输入耗尽,zip 函数会立即停止生成值。如果要保持迭代,直至最长的可迭代对象耗尽,可以使用 itertools.zip_longest,它提供可选的 fillvalue(默认为 None)填充缺失的值:

1
2
3
4
5
6
>>> list(zip(range(1,3), 'ABCEDF'))
[(1, 'A'), (2, 'B')]
>>> from itertools import zip_longest
>>> list(zip_longest(range(1,3), 'ABCEDF', -1))
>>> list(zip_longest(range(1,3), 'ABCEDF', fillvalue=-1))
[(1, 'A'), (2, 'B'), (-1, 'C'), (-1, 'E'), (-1, 'D'), (-1, 'F')]

关于 __eq__ 函数,还有一种更简单的写法:

1
2
3
def __eq__(self, other):
return len(self) == len(other) and \
all(a == b for a, b in zip(self, other))

除了 zip 函数,为了避免在 for 循环中手动处理索引变量,还经常使用内置的 enumerate 生成器函数:

1
2
3
4
5
6
7
8
9
>>> for i, v in enumerate('abcdef'):
... print(i, v)
...
0 a
1 b
2 c
3 d
4 e
5 f

格式化

最后,为了使 Vector 类支持以球面坐标的方式格式化,重写了其 __format__ 函数,使其支持 h 格式代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def angle(self, n):
r= math.sqrt(sum(x * x for x in self[n:]))
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and self[-1] < 0:
return math.pi * 2 - a
else:
return a

def angles(self):
return (self.angle(n) for n in range(1, len(self)))

def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'):
fmt_spec = fmt_spec[:-1]
coords = itertools.chain([abs(self)], self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))

这里使用了 itertools.chain 函数生成生成器表达式,用于无缝迭代向量的模和各个角坐标。测试如下

1
2
3
4
5
6
7
8
>>> import vector_v5
>>> v = vector_v5.Vector(range(3))
>>> format(v, '.3f')
'(0.000, 1.000, 2.000)'
>>> format(v)
'(0.0, 1.0, 2.0)'
>>> format(v, '.3fh')
'<2.236, 1.571, 1.107>'