文章

Python Q&A: “重载”与动态分派

顺带把第二期写了吧。

Python Q&A: “重载”与动态分派

Question.

Python似乎没有其他语言(C++与Java)所说的“重载”机制,还是说它的重载表现的与其他语言完全不同?如果我设置的这个function能接受可变数量的参数,甚至能根据参数的类型动态实现不同的方法,而不是在函数内写一串 if 或者match - case,那该多好。

Answer.

Python 确实没有所谓“重载”机制——明确一下,Java、C++ 语境的重载是“同函数名,不同逻辑”,例如这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class OverloadExample {  
  
    public int add(int a, int b) {  
        return a + b;  
    }
  
    public double add(double a, double b) {  
        return a + b;  
    }
  
    public static void main(String[] args) {  
        OverloadExample example = new OverloadExample();  
  
        // 调用不同的add方法  
        System.out.println("两个整数的和: " + example.add(5, 10));  
        System.out.println("两个浮点数的和: " + example.add(5.5, 10.5));  
    }
}

但在 Python 你做不到,因为后面定义的同名函数会覆盖掉前面定义的函数。例如这个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
>>> # 该代码运行于 Python REPL
>>> def fun(param1, param2):
...     print(f'hello {param1} from {param2} !')
...     
>>> fun('hyli360', 'NWU')
hello hyli360 from NWU !
>>> # 定义一个同名函数,但具有不同参数签名
>>> def fun(param1, param2, param3):
...     print(f'hello {param1} from {param2} on {param3}!')
...     
>>> # 此时 Python 只会记住新的参数签名,而不是上一个签名
>>> fun('hyli360', 'NWU', 'Linux')
hello hyli360 from NWU on Linux!
>>> fun('hyli360', 'NWU')
Traceback (most recent call last):
  File "<python-input-13>", line 1, in <module>
    fun('hyli360', 'NWU')
    ~~~^^^^^^^^^^^^^^^^^^
TypeError: fun() missing 1 required positional argument: 'param3'

笔者在接触更 Pythonic(更有 Python 味)的做法之前,确实也有用 ifmatch - case;情况更多的时候,就将情形与调用的函数写成键值对,通过查表动态调用。性能是有保证了,但代码看上去仍不是很舒服,且割裂感很重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def fun1(param1, param2):
    print(f"fun1 has been called! received params: {param1} and {param2}")

def fun2(param1, param2):
    print(f"fun2 has been called! received params: {param1} and {param2}")

fun_table = {
    True: fun1,
    False: fun2,
}

def fun_dispatch(condition: bool):
    param1 = 'param1'
    param2 = 'param2'
    
    fun_table[condition](param1, param2)

顺带一提,这种编程理念也叫“表驱动编程”(Table-Driven Methods),在其他语言里其实非常常见。例如,在 C 语言上写一个按月份返回当月天数的算法:

1
2
3
4
5
static int monthDays[12] = {31,28,31,30,31,30,31,31,30,31,30,31};

int iGetMonthDays(int iMonth){
    return monthDays[(iMonth - 1)];
}

不过,Python 其实早就对这种问题给出了自己的解法,而且其实还不错——基于装饰器的动态分派。如果不了解装饰器,可以先看一下上一篇的 Q&A 1,然后继续看接下来的内容。

Python 自己已经提供了一个基础版的动态分派装饰器:functools.singledispatch,也就是基于第一个参数的类型进行分派(简称“单分派”)。

例如我们有一个 FASTA 数据抽取器(fasta_extractor),它可以根据我们提供的序列 ID 列表,抽选带有该 ID 的序列。我们希望它能够同时读取这些类型的数据:

  • Path 对象,也就是文件路径;
  • list[SeqRecord] 或者 dict[str, SeqRecord] 对象(SeqRecord 列表或 {SeqRecord.id: SeqRecord} 键值对)。
1
2
3
4
5
6
7
8
9
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
from pathlib import Path

def fasta_extractor(
    fasta: Path | list[SeqRecord] | dict[str, SeqRecord],
    entries_id: list,
):
    pass

可以看到,函数的参数签名(函数名后面的括号部分)里,任意一个参数可以分配不同的类型签名,但后面仍需要根据其类型(isinstance())进行分流。

使用 functools.singledispatch 就可以改善这个问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
from pathlib import Path
from functools import singledispatch

# 这里定义下面几个分派都不符合的情形
# 比较好的工程实践是,如果分派失败,抛出 TypeError,从而能够确定哪个分派分支没有定义
@singledispatch
def fasta_extractor(
    fasta: Path | list[SeqRecord] | dict[str, SeqRecord],
    entries_id: list,
):
    raise TypeError(f'fasta is a {type(fasta)}, which has no dispatch branch to handle it')

@fasta_extractor.register
def _(fasta: Path, entries_id):
    # 检查路径是否存在
    # 使用 SeqIO.parse 尝试读取
    # ......
    pass

@fasta_extractor.register
def _(fasta: list, entries_id):
    # ......
    pass

@fasta_extractor.register
def _(fasta: dict, entries_id):
    # ......
    pass

如果要作用到类方法,就不能用 singledispatch 了,而是 singledispatchmethod

另外,这个方法的局限也很明显——你只能对第一个参数的类型进行分派,而无法实现多参数类型分派,这需要第三方库 multipledispatch 支持,而且后续要处理(参数1类型数量 * 参数2类型数量 * …… * 参数 N 类型数量)个分派分支,复杂到不如多定义几个函数。谁需要谁去用吧……

另外的另外,其实标准库 typing 还真提供了一个名字为“重载”(@overload)的装饰器,但这个玩意实在很难与重载搭上任何关系,因为 typing 库的功能(类型提示)就意味着它只能起到一点类型检查的作用,甚至都不如上面这种:

1
Path | list[SeqRecord] | dict[str, SeqRecord]

@overload,只会把上面那个变成这个样子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
from pathlib import Path
from typing import overload

@overload
def fasta_extractor(fasta: Path, entries_id: list,): ...
@overload
def fasta_extractor(fasta: Path, entries_id: list[SeqRecord],): ...
@overload
def fasta_extractor(fasta: Path, entries_id: dict[str, SeqRecord],): ...

def fasta_extractor(fasta, entries_id,):
    # 仍然需要根据输入的对象类型分门别类地处理!
    pass

……这个也爱谁用谁用吧。

(补充。typing.overload 解决的并非分派/重载问题,而是输入-输出类型问题。如果输入类型与输出类型绑定,例如“你给我 list,我就还你 dict”,使用这个装饰器进行声明就能让静态检查器搞清楚状况,从而避免奇怪的类型警告。)

1
2
@overload
def fasta_extractor(fasta: list[SeqRecord], entries_id: list,) -> dict: ...