这篇文章将自定义一个序列类型 Vector,进一步探讨 Python 序列的修改、散列和切片等操作所涉及的背后原理。
Vector 类:用户定义的序列类型 这里实现一个 Vector 类,这个类的行为与 Python 中标准的不可变扁平序列基本一致:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 from array import arrayimport reprlibimport mathclass Vector : typecode = 'd' def __init__ (self, components ): self._components = array(self.typecode, components) def __iter__ (self ): return iter (self._components) def __repr__ (self ): components = reprlib.repr (self._components) components = components[components.find('[' ):-1 ] return 'Vecotr({})' .format (components) def __str__ (self ): return str (tuple (self)) def __bytes__ (self ): return bytes (ord (self.typecode)) + bytes (self._components) def __eq__ (self, other ): return tuple (self) == tuple (other) def __abs__ (self ): return math.sqrt(sum (x * x for x in self)) def __bool__ (self ): return bool (abs (self)) @classmethod def frombytes (cls, octets ): typecode = chr (octets[0 ]) memv = memoryview (octets[1 :].cast(typecode)) return cls(memv)
1 2 3 4 5 6 7 >>> import vector_v1>>> v1 = vector_v1.Vector((2 , 3 , 4 ))>>> v2 = vector_v1.Vector(range (10 ))>>> v1Vecotr([2.0 , 3.0 , 4.0 ]) >>> v2Vecotr([0.0 , 1.0 , 2.0 , 3.0 , 4.0 , ...])
这里 _components
是一个受保护的实例属性
序列类型的构造方法最好接收可迭代的对象作为参数,因为所有内置的序列类型都是这样做的
字符串的表示形式是用于调试,使用 reprlib 模块可以生成长度有限的表示形式。reprlib.repr 用于生成大型结构或递归结构的安全形式,它会限制输出字符串的长度。
协议和鸭子类型 在 Python 中,创建功能完善的序列类型无需使用继承,只需实现符合序列协议的方法。在面对对象编程中,协议是非正式的接口,只在文档中定义,在代码中不定义。例如 Python 序列协议只要实现了 __len__
和 __getitem__
两个方法,就能用于任何期待序列的地方。我们说它是序列,因为它的行为像序列,人们将其称为鸭子类型。
协议是非正式的,没有强制力,如果你知道你的类的具体使用场景,通常只需要实现一个协议的部分。例如如果只是为了支持迭代,只需要实现 __getitem__
方法,没有必要提供 __len__
方法。
Vector 类第 2 版:可切片的序列 只需要添加如下方法,Vector 类就支持了序列协议:
1 2 3 4 5 def __len__ (self ): return len (self._components) def __getitem__ (self, position ): return self._components[position]
1 2 3 4 5 6 7 8 9 10 >>> import vector_v2>>> v3 = vector_v2.Vector(range (5 ))>>> v3Vecotr([0.0 , 1.0 , 2.0 , 3.0 , 4.0 ]) >>> v3[1 ]1.0 >>> v3[2 :3 ]array('d' , [2.0 ]) >>> len (v3)5
现在 Vector 已经支持切片了,但是切片返回的实例是 array 对象,而不是 Vector 实例。内置的序列类型,其切片结果也都是各自类型的新实例,而不是其他类型。为了让 Vector 实例的切片仍然是 Vector 的实例,就不能简单地委托给数组切片。需要分析传递给 getitem 方法的参数,做适当的处理,下面的简单例子可以说明这一点:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 >>> class MySeq :... def __getitem__ (self, index ):... return index... >>> s = MySeq()>>> s[1 ]1 >>> s[1 :4 ]slice (1 , 4 , None )>>> s[1 :4 :2 ]slice (1 , 4 , 2 )>>> s[1 :4 :2 , 9 ](slice (1 , 4 , 2 ), 9 ) >>> s[1 :4 :2 , 7 :9 ](slice (1 , 4 , 2 ), slice (7 , 9 , None ))
可以看到,如果 []
中有逗号,那么 __getitem__
收到的是元祖。slice 提供一个 indices(len) 方法,对于给定长度为 len 的序列,计算 S 表示的扩展切片的起始(start)和结尾索引(stop)、以及步幅(stride)。超出边界的索引会被截掉。indices 方法开放了内置序列实现的棘手逻辑,用于优雅地处理缺失索引和负数索引,以及长度超过目标序列的切片。该方法会整顿元祖,把 start、up、stride
都变成非负数,而且都落在指定长度序列的边界内。例如:
1 2 3 4 5 SyntaxError: invalid character ')' (U+FF09) >>> slice (None , 10 , 2 ).indices(5 )(0 , 5 , 2 ) >>> slice (-3 , None , None ).indices(5 )(2 , 5 , 1 )
在 Vector 类中无需使用 slice.indices()
方法,因为我们直接将切片参数委托给 _components
数组处理。当你没有底层序列类型作为依靠时,那么使用这个方法能够节省大量时间。接下来重新实现能够处理切片的 __getitem__
方法:
1 2 3 4 5 6 7 8 9 def __getitem__ (self, position ): cls = type (self) if isinstance (position, slice ): return cls(self._components[position]) elif isinstance (position, numbers.Integral): return self._components[position] else : msg = '{cls.__name__} indices must be intergers' raise TypeError(msg.format (cls=cls))
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 >>> import vector_v2>>> v1 = vector_v2.Vector(range (10 ))>>> v1[1 ]1.0 >>> v1[1 :3 ]Vecotr([1.0 , 2.0 ]) >>> v1[1 :5 :2 ]Vecotr([1.0 , 3.0 ]) >>> v1[1 :2 ]Vecotr([1.0 ]) >>> v1[1 ,2 ]Traceback (most recent call last): File "<stdin>" , line 1 , in <module> File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v2.py" , line 57 , in __getitem__ raise TypeError(msg.format (cls=cls)) TypeError: Vector indices must be intergers
动态存取属性 现在的 Vector 类可有有大量分量,但是无法通过名称访问向量的分量了。如果能通过单个字母来访问前几个分量的话会比较方便。例如用 x、y、z 代替 v[0]、v[1]、v[2]。通过特殊方法 __getattr__
可以实现这一点。
属性查找失败后,解释器会调用 __getattr__
方法。属性查找机制可以简单归纳如下,以 my_obj.x 表达式为例:
解释器首先查找 my_obj 实例有没有 x 属性
如果没有,到类 my_obj.class 中查找
如果还没有,顺着继承树继续查找
如果依旧查找不到,调用 my_obj 所属类中定义的 __getattr__
方法,传入 self 和属性名称的字符串形式
如下实现了 __getattr__
方法:
1 2 3 4 5 6 7 8 9 10 shortcut_names = 'xyzt' def __getattr__ (self, name ): cls = type (self) if len (name) == 1 : pos = cls.shortcut_names.find(name) if 0 <= pos < len (cls.shortcut_names): return self._components[pos] msg = '{.__name__!r} object has no attribute {!r}' raise AttributeError(msg.format (cls, name))
简单的测试如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 >>> import vector_v3>>> v1 = vector_v3.Vector(range (10 ))>>> v1.x0.0 >>> v1.y1.0 >>> v1.z2.0 >>> v1.t3.0 >>> v1.wTraceback (most recent call last): File "<stdin>" , line 1 , in <module> File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v3.py" , line 68 , in __getattr__ raise AttributeError(msg.format (cls, name)) AttributeError: 'Vector' object has no attribute 'w'
在看下面这个使用例子:
1 2 3 4 5 6 7 >>> v1.x0.0 >>> v1.x = 10 >>> v1Vecotr([0.0 , 1.0 , 2.0 , 3.0 , 4.0 , ...]) >>> v1.x10
这是因为当使用 v1.x = 10
赋值语句后,v 对象就有了 x 属性。因此 v1 对象就有了 x 属性了,因此使用 v1.x
就不会再调用 __getattr__
方法了。为了避免这种前后矛盾的现象,需要改写 Vector 类中设置属性的逻辑,即 __setattr__
方法,大多数时候如果定义了 __getattr__
方法,那么也要定义 __setattr__
方法,这样才能避免行为不一致:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def __setattr__ (self, name, value ): cls = type (self) if len (name) == 1 : if name in cls.shortcut_names: error = 'readonly attribute {attr_name!r}' elif name.islower(): error = "can't set attribute a to z in {cls_name!r}" else : error = '' if error: msg = error.format (cls_name=cls.__name__, attr_name=name) raise AttributeError(msg) super ().__setattr__(name, value)
该方法对于非单个小写英文字母的属性,仍然提供标准些行为,这是通过在超类中调用 __setattr__
实现的。程序测试结果如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 Vecotr([1.0 ]) >>> import vector_v3>>> v1 = vector_v3.Vector(range (10 ))>>> v1.x = 10 Traceback (most recent call last): File "<stdin>" , line 1 , in <module> File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v3.py" , line 82 , in __setattr__ raise AttributeError(msg) AttributeError: readonly attribute 'x' >>> v1.a = 10 Traceback (most recent call last): File "<stdin>" , line 1 , in <module> File "/Users/fuchencong/data/workspace/code/private/fluent_python/vector_v3.py" , line 82 , in __setattr__ raise AttributeError(msg) AttributeError: can't set attribute a to z in ' Vector' >>> v1.value = 10 >>> v1.value 10
之前介绍过,在类中声明 __slots__
属性可以防止设置新的实例属性,但是当时也说过,__slots__
属性只应该用于节省内存,不建议只为了避免创建实例属性而使用 __slots__
。
如果想允许修改分量,可以使用 __setitem__
方法,以支持 v[0]= XXX
的赋值方式,或者实现 __setattr__
方法,以支持 v.x = XXX
的赋值方式。但是为了让 Vector 成为可散列类型,我们没有实现这些方法,以保持 Vector 是不可变的。
散列和快速等值测试 接下来实现 Vector 类型的 __hash__
方法,加上现有的 __eq__
方法,这可以把 Vector 实例变成可散列的对象。__hash__
方法使用 ^
运算符计算各个分量的散列值,要实现该运算,可以有多重写法,例如:
1 2 3 4 5 6 7 8 9 10 11 12 >>> n = 0 >>> for i in range (1 , 6 ):... n ^= i... >>> n1 >>> import functools>>> functools.reduce(lambda a, b: a^b, range (6 ))1 >>> import operator>>> functools.reduce(operator.xor, range (6 ))1
最后一种方法最为简单,operator 模块以函数的形式,提供了 python 的全部中缀运算符,从而减少了 lambda 表达式。如下给出了 __hash__
方法的实现:
1 2 3 def __hash__ (self ): hashes = (hash (x) for x in self._components) return functools.reduce(operator.xor, hashes, 0 )
这里再 reduce 中提供了第三个参数 initializer。如果序列为空,initializer 是返回的结果,否则在规约中使用它作为第一个参数,因此应该使用恒等式。对于 +、|、*
来说,initializer 应该是 0,而对于 *
和 &
来说,应该是 1。
另外,为了提高 Vector 类中 __eq__
方法的效率(原方法需要复制操作对象,以创建元祖进行对比,当 Vector 包含上万个分量时效率不高),重写 __eq__
方法:
1 2 3 4 5 6 7 8 def __eq__ (self, other ): if len (self) != len (other): return False for a, b in zip (self, other): if a != b: return False return True
这里 zip 函数接收可迭代对象作为参数,返回一个由元祖构成的生成器,元祖中的元素就是来自于参数中可迭代对象。zip 函数能够轻松地并行迭代两个或多个可迭代对象,它返回的元祖可以拆包成变量,分别对应各个并行输入对象中的一个元素。这里比较长度是必要的,因为 zip 函数处理过程中,一旦有一个输入耗尽,zip 函数会立即停止生成值。如果要保持迭代,直至最长的可迭代对象耗尽,可以使用 itertools.zip_longest
,它提供可选的 fillvalue(默认为 None)填充缺失的值:
1 2 3 4 5 6 >>> list (zip (range (1 ,3 ), 'ABCEDF' ))[(1 , 'A' ), (2 , 'B' )] >>> from itertools import zip_longest>>> list (zip_longest(range (1 ,3 ), 'ABCEDF' , -1 ))>>> list (zip_longest(range (1 ,3 ), 'ABCEDF' , fillvalue=-1 ))[(1 , 'A' ), (2 , 'B' ), (-1 , 'C' ), (-1 , 'E' ), (-1 , 'D' ), (-1 , 'F' )]
关于 __eq__
函数,还有一种更简单的写法:
1 2 3 def __eq__ (self, other ): return len (self) == len (other) and \ all (a == b for a, b in zip (self, other))
除了 zip 函数,为了避免在 for 循环中手动处理索引变量,还经常使用内置的 enumerate 生成器函数:
1 2 3 4 5 6 7 8 9 >>> for i, v in enumerate ('abcdef' ):... print (i, v)... 0 a1 b2 c3 d4 e5 f
格式化 最后,为了使 Vector 类支持以球面坐标的方式格式化,重写了其 __format__
函数,使其支持 h
格式代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 def angle (self, n ): r= math.sqrt(sum (x * x for x in self[n:])) a = math.atan2(r, self[n-1 ]) if (n == len (self) - 1 ) and self[-1 ] < 0 : return math.pi * 2 - a else : return a def angles (self ): return (self.angle(n) for n in range (1 , len (self))) def __format__ (self, fmt_spec='' ): if fmt_spec.endswith('h' ): fmt_spec = fmt_spec[:-1 ] coords = itertools.chain([abs (self)], self.angles()) outer_fmt = '<{}>' else : coords = self outer_fmt = '({})' components = (format (c, fmt_spec) for c in coords) return outer_fmt.format (', ' .join(components))
这里使用了 itertools.chain 函数生成生成器表达式,用于无缝迭代向量的模和各个角坐标。测试如下
1 2 3 4 5 6 7 8 >>> import vector_v5>>> v = vector_v5.Vector(range (3 ))>>> format (v, '.3f' )'(0.000, 1.000, 2.000)' >>> format (v)'(0.0, 1.0, 2.0)' >>> format (v, '.3fh' )'<2.236, 1.571, 1.107>'