1 | np.sort(a, kind='mergesort') |
排序是计算机科学最重要的内容之一,NumPy作为大名鼎鼎的数值运算库对排序也有良好的支持。然而如果你一时好奇想一窥NumPy的排序源码,看看是不是有什么惊天动地的优化,却没那么容易。因为NumPy的排序功能完全由C实现,只靠pip install numpy
得到的二进制文件是无法阅读这部分源码的。想知道上面这行排序代码背后是什么故事,只能去GitHub看官方的repo。
本文将从两个角度来解读NumPy(版本:1.15)中归并排序的源码:
- 算法角度——NumPy是怎样进行归并排序的。
- 工程角度——这一排序算法是如何成为Python中易用的接口的。
归并排序的算法分析
我们首先来看看归并排序mergesort.c
中对整数(int)进行排序的代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59static void
mergesort0_int(npy_int *pl, npy_int *pr, npy_int *pw)
{
npy_int vp, *pi, *pj, *pk, *pm;
if (pr - pl > SMALL_MERGESORT) {
/* merge sort */
pm = pl + ((pr - pl) >> 1);
mergesort0_int(pl, pm, pw);
mergesort0_int(pm, pr, pw);
for (pi = pw, pj = pl; pj < pm;) {
*pi++ = *pj++;
}
pi = pw + (pm - pl);
pj = pw;
pk = pl;
while (pj < pi && pm < pr) {
if (INT_LT(*pm, *pj)) {
*pk++ = *pm++;
}
else {
*pk++ = *pj++;
}
}
while(pj < pi) {
*pk++ = *pj++;
}
}
else {
/* insertion sort */
for (pi = pl + 1; pi < pr; ++pi) {
vp = *pi;
pj = pi;
pk = pi - 1;
while (pj > pl && INT_LT(vp, *pk)) {
*pj-- = *pk--;
}
*pj = vp;
}
}
}
int
mergesort_int(void *start, npy_intp num, void *NOT_USED)
{
npy_int *pl, *pr, *pw;
pl = start;
pr = pl + num;
pw = malloc((num/2) * sizeof(npy_int));
if (pw == NULL) {
return -NPY_ENOMEM;
}
mergesort0_int(pl, pr, pw);
free(pw);
return 0;
}
这段排序代码的入口是mergesort_int
。这个函数对真正的递归归并排序算法mergesort0_int
进行了包装,包含三个参数。不难想到mergesort_int
第一个参数是待排序数组的起始地址,第二个参数是待排序数组的大小。而第三个参数写的很清楚是不使用的,是一个工程问题,下面会提到。
算法首先规定了两个整数指针pl
(left)、pr
(right)分别指向待排序数组的起始和末尾,随后分配了大小为待排序数组大小一半的buffer用于排序(归并排序是需要额外空间的),最后调用排序算法mergesort0_int
。进入真正的排序算法mergsort0_int
之后程序首先对待排序数组的大小进行了分析,如果小于规定的阈值SMALL_MERGESORT
就会利用插入排序进行排序,否则再进行归并排序。插入排序的时间复杂度虽然是\(\mathcal{O}(N^2)\),但是在数组较小时overhead非常小,数据连续性也很棒,效果想来是相当不错。这可以算是NumPy
的归并排序比起我们做练习写的垃圾归并排序代码的第一个优化了。
可能有同学会问,这个SMALL_MERGESORT
究竟是多大呢?在mergesort.c
的头部进行了如下的定义:1
是不是比你预想的要大一些?插入排序在基本有序数组上也有着恐怖表现,真是杀人越货居家旅行必备呀。
插入排序的代码没什么好看的,我们来看当数据规模大于SMALL_MERGESORT
时的归并排序代码。首先注意到pm
这样一个变量。这个整数指针指向的是归并排序进行分治时两个子数组的交界位置,计算其值的一般实现是:1
pm = (pl + pr) / 2;
而NumPy里的实现是:1
pm = pl + ((pr - pl) >> 1);
这不得不让人想到经典的“二分查找十年Bug”梗。把这些细节都做好真的不容易,NumPy作为一个数值运算库这一点上还是很专业的。
随后我们再来看递归调用mergesort0_int
对两个子区间分别进行排序的代码,注意到除了规定数组的两个指针外,作为排序辅助空间的数组也被传了进去。这意味着在这一系列递归的排序中,算法使用的辅助空间是同一个。这粗想起来好像说不通,其实程序还是线性执行的,没有任何并行,所以不会有任何问题。而一般的归并排序递归实现往往都在当前函数堆栈中申请辅助空间,自然也就引入了系统调用的额外时间开销。如果没有对数组较小的情况进行特殊处理(如NumPy中一样)的话,系统调用的开销会是相当巨大的。这是NumPy归并排序的第二个优化。
接下来的代码就是普通的归并排序了,虽然符号都很抽象,但搞清楚含义后都不难懂。我本科学数据结构的时候在归并时遇到的是这样的代码:1
2
3
4
5
6
7
8
9
10
11
12template <typename T> //有序向量的归并
void Vector<T>::merge ( Rank lo, Rank mi, Rank hi ) { //各自有序的子向量[lo, mi)和[mi, hi)
T* A = _elem + lo; //合并后的向量A[0, hi - lo) = _elem[lo, hi)
int lb = mi - lo; T* B = new T[lb]; //前子向量B[0, lb) = _elem[lo, mi)
for ( Rank i = 0; i < lb; B[i] = A[i++] ); //复制前子向量
int lc = hi - mi; T* C = _elem + mi; //后子向量C[0, lc) = _elem[mi, hi)
for ( Rank i = 0, j = 0, k = 0; ( j < lb ) || ( k < lc ); ) { //B[j]和C[k]中的小者续至A末尾
if ( ( j < lb ) && ( ! ( k < lc ) || ( B[j] <= C[k] ) ) ) A[i++] = B[j++];
if ( ( k < lc ) && ( ! ( j < lb ) || ( C[k] < B[j] ) ) ) A[i++] = C[k++];
}
delete [] B; //释放临时空间B
} //归并后得到完整的有序向量[lo, hi)
感觉看起来还是乱一些……
排序算法的工程分析
排序算法文件的结构
NumPy实现了快速排序、归并排序、堆排序三种排序方式,源码位于numpy/core/src/npysort
目录下。如果你打开这个目录看一下的话,你就会发现——根本就没有我前面吹牛的时候提到的mergesort.c
,以及想象中的quicksort.c
、heapsort.c
等。相反,你看到的是一堆*.c.src
文件,如mergesort.c.src
。打开一看,卧槽,这莫非是C++2077?1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16int
mergesort_void *start, npy_intp num, void *NOT_USED) @(
{
@ *pl, *pr, *pw;
pl = start;
pr = pl + num;
pw = malloc((num/2) * sizeof( @));
if (pw == NULL) {
return -NPY_ENOMEM;
}
mergesort0_ @(pl, pr, pw);
free(pw);
return 0;
}
之所以出现了许多由@
包裹的变量,是由于NumPy是基于C的,而C没有泛型。前面我们介绍了对32位有符号整数的归并排序,那么其它整数呢?浮点数呢?C对这类问题只有一个解决方法,那就是对每种变量写对应的函数。为了避免重复劳动,NumPy采用的方案是用c.src
文件来写一个模板,然后在安装包时利用Python批量生成c
文件。在numpy/
目录下就有一个setup.py
的文件,专门规定了各种编译前需要由*.c.src
转化为*.c
的文件。渲染模板的文件位于和分发有关的distutils
目录下,其头部的文档中规定了NumPy中使用的这种模板的语法:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53takes templated file .xxx.src and produces .xxx file where .xxx is
.i or .c or .h, using the following template rules
/**begin repeat -- on a line by itself marks the start of a repeated code
segment
/**end repeat**/ -- on a line by itself marks it's end
After the /**begin repeat and before the */, all the named templates are placed
these should all have the same number of replacements
Repeat blocks can be nested, with each nested block labeled with its depth,
i.e.
/**begin repeat1
*....
*/
/**end repeat1**/
When using nested loops, you can optionally exclude particular
combinations of the variables using (inside the comment portion of the inner loop):
:exclude: var1=value1, var2=value2, ...
This will exclude the pattern where var1 is value1 and var2 is value2 when
the result is being generated.
In the main body each replace will use one entry from the list of named replacements
Note that all #..# forms in a block must have the same number of
comma-separated entries.
Example:
An input file containing
/**begin repeat
* #a = 1,2,3#
* #b = 1,2,3#
*/
/**begin repeat1
* #c = ted, jim#
*/
@a@, @b@, @c@
/**end repeat1**/
/**end repeat**/
produces
line 1 "template.c.src"
/*
*********************************************************************
** This file was autogenerated from a template DO NOT EDIT!!**
** Changes should be made to the original source (.src) file **
*********************************************************************
*/
#line 9
1, 1, ted
#line 9
1, 1, jim
#line 9
2, 2, ted
#line 9
2, 2, jim
#line 9
3, 3, ted
#line 9
3, 3, jim
这种语法其实是非常简洁和强大的,仔细思考的话会发现比jinja2
之类的模板系统用起来方便,因为模板和变量时定义在一个文件当中的。现在我们再来看mergesort.c.src
中是怎样规定@type@
和@suff@
等变量的:1
2
3
4
5
6
7
8
9
10
11
12
13/**begin repeat
*
* #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
* LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA#
* #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
* longlong, ulonglong, half, float, double, longdouble,
* cfloat, cdouble, clongdouble, datetime, timedelta#
* #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int,
* npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat,
* npy_cdouble, npy_clongdouble, npy_datetime, npy_timedelta#
*/
可见采用模板消灭的重复劳动还是很可观的。只要把INT-int-npy_int
这一组变量代入到上面带有@
的代码,就可以还原出第一节介绍的归并排序代码了。
其实用Python来生成C代码还是很常见的,加上现代的编译分发工具,某种程度上也不失为对C++的挑战。不过话说回来,为什么不能用C++来写一个NumPy2呢?可维护性、可扩展性会不会更好呢?我觉得这个问题的答案是没有必要。当需要精确的底层数据结构操控时,C游刃有余;而一旦脱离了底层NumPy立刻可以无缝切换到比C++开发效率更高的Python。
回到mergesort.c.src
中,我们还会发现除了形如mergesort_@suff@
的排序函数之外,还有一类形如amergesort_@suff@
的排序函数,这类函数是用来进行np.argsort操作的:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61static void
amergesort0_@suff@(npy_intp *pl, npy_intp *pr, @type@ *v, npy_intp *pw)
{
@type@ vp;
npy_intp vi, *pi, *pj, *pk, *pm;
if (pr - pl > SMALL_MERGESORT) {
/* merge sort */
pm = pl + ((pr - pl) >> 1);
amergesort0_@suff@(pl, pm, v, pw);
amergesort0_@suff@(pm, pr, v, pw);
for (pi = pw, pj = pl; pj < pm;) {
*pi++ = *pj++;
}
pi = pw + (pm - pl);
pj = pw;
pk = pl;
while (pj < pi && pm < pr) {
if (@TYPE@_LT(v[*pm], v[*pj])) {
*pk++ = *pm++;
}
else {
*pk++ = *pj++;
}
}
while(pj < pi) {
*pk++ = *pj++;
}
}
else {
/* insertion sort */
for (pi = pl + 1; pi < pr; ++pi) {
vi = *pi;
vp = v[vi];
pj = pi;
pk = pi - 1;
while (pj > pl && @TYPE@_LT(vp, v[*pk])) {
*pj-- = *pk--;
}
*pj = vi;
}
}
}
int
amergesort_@suff@(void *v, npy_intp *tosort, npy_intp num, void *NOT_USED)
{
npy_intp *pl, *pr, *pw;
pl = tosort;
pr = pl + num;
pw = malloc((num/2) * sizeof(npy_intp));
if (pw == NULL) {
return -NPY_ENOMEM;
}
amergesort0_@suff@(pl, pr, v, pw);
free(pw);
return 0;
}
比较两类函数,不难发现他们的总体逻辑都是非常类似的。区别首先在于函数原型,argsort的形参包括一个tosort
的数组,这是用来保存argsort的结果的。在传入之前,tosort
被初始化成了从0到待排序数组大小的数组以映射待排序数组,因此可以通过访问待排序数组下角标为tosort
中元素的值的方式获取tosrot
中元素的值的“大小”信息,简化实现。
除此之外,在*.c.src
中还对两种情况作了特殊处理,第一是字符串排序,第二是Python对象排序,这两种情况都需要特殊元素比较方式和元素拷贝方式,因此进行了比较特殊的实现。
排序算法文件与工程的集成
不同的排序对象、不同的排序算法、是否是argsort,一个因素不同就要重新定义一个函数。这使得numpy/core/src/npysort
中包含了大量的排序函数,这些函数原型都包含在numpy/core/src/common
下的npy_sort.h
头文件中。这一文件似乎并不是依据模板生成的,是NumPy一个可以改进的问题(20181112又及:经本人PR已解决)。简单来讲,其它的模块只需要包含这个头文件,就可以使用各种排序算法了,然而NumPy又对这些算法做了更高级的包装。NumPy首先在numpy/core/include/numpy/ndarraytypes.h
中规定了排序的函数类型:1
2typedef int (PyArray_SortFunc)(void *, npy_intp, void *);
typedef int (PyArray_ArgSortFunc)(void *, npy_intp *, npy_intp, void *);
然后又在同文件中定义了一个名叫PyArray_ArrFuncs
的struct
,其中包含了大量的函数指针:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110typedef struct {
/*
* Functions to cast to most other standard types
* Can have some NULL entries. The types
* DATETIME, TIMEDELTA, and HALF go into the castdict
* even though they are built-in.
*/
PyArray_VectorUnaryFunc *cast[NPY_NTYPES_ABI_COMPATIBLE];
/* The next four functions *cannot* be NULL */
/*
* Functions to get and set items with standard Python types
* -- not array scalars
*/
PyArray_GetItemFunc *getitem;
PyArray_SetItemFunc *setitem;
/*
* Copy and/or swap data. Memory areas may not overlap
* Use memmove first if they might
*/
PyArray_CopySwapNFunc *copyswapn;
PyArray_CopySwapFunc *copyswap;
/*
* Function to compare items
* Can be NULL
*/
PyArray_CompareFunc *compare;
/*
* Function to select largest
* Can be NULL
*/
PyArray_ArgFunc *argmax;
/*
* Function to compute dot product
* Can be NULL
*/
PyArray_DotFunc *dotfunc;
/*
* Function to scan an ASCII file and
* place a single value plus possible separator
* Can be NULL
*/
PyArray_ScanFunc *scanfunc;
/*
* Function to read a single value from a string
* and adjust the pointer; Can be NULL
*/
PyArray_FromStrFunc *fromstr;
/*
* Function to determine if data is zero or not
* If NULL a default version is
* used at Registration time.
*/
PyArray_NonzeroFunc *nonzero;
/*
* Used for arange.
* Can be NULL.
*/
PyArray_FillFunc *fill;
/*
* Function to fill arrays with scalar values
* Can be NULL
*/
PyArray_FillWithScalarFunc *fillwithscalar;
/*
* Sorting functions
* Can be NULL
*/
PyArray_SortFunc *sort[NPY_NSORTS];
PyArray_ArgSortFunc *argsort[NPY_NSORTS];
/*
* Dictionary of additional casting functions
* PyArray_VectorUnaryFuncs
* which can be populated to support casting
* to other registered types. Can be NULL
*/
PyObject *castdict;
/*
* Functions useful for generalizing
* the casting rules.
* Can be NULL;
*/
PyArray_ScalarKindFunc *scalarkind;
int **cancastscalarkindto;
int *cancastto;
PyArray_FastClipFunc *fastclip;
PyArray_FastPutmaskFunc *fastputmask;
PyArray_FastTakeFunc *fasttake;
/*
* Function to select smallest
* Can be NULL
*/
PyArray_ArgFunc *argmin;
} PyArray_ArrFuncs;
这个结构是NumPy的核心结构之一,在官方文档中也有介绍。我们看到,PyArray_ArrFuncs
中有两个和排序函数指针有关的数组,分别是PyArray_SortFunc *sort[NPY_NSORTS];
和PyArray_ArgSortFunc *argsort[NPY_NSORTS];
。至此,在本文第一部分遇到的部分排序函数含有一个冗余参数的现象的原因就呼之欲出了。NumPy在numpy/core/src/multiarray/arraytypes.c.src中对各种类型的array
进行了初始化设定,其中包括了设定PyArray_ArrFuncs
的代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
{
@from@_to_BOOL,
@from@_to_BYTE,
@from@_to_UBYTE,
@from@_to_SHORT,
@from@_to_USHORT,
@from@_to_INT,
@from@_to_UINT,
@from@_to_LONG,
@from@_to_ULONG,
@from@_to_LONGLONG,
@from@_to_ULONGLONG,
@from@_to_FLOAT,
@from@_to_DOUBLE,
@from@_to_LONGDOUBLE,
@from@_to_CFLOAT,
@from@_to_CDOUBLE,
@from@_to_CLONGDOUBLE,
@from@_to_OBJECT,
@from@_to_STRING,
@from@_to_UNICODE,
@from@_to_VOID
},
@from@_getitem,
@from@_setitem,
(PyArray_CopySwapNFunc*)@from@_copyswapn,
(PyArray_CopySwapFunc*)@from@_copyswap,
(PyArray_CompareFunc*)@from@_compare,
(PyArray_ArgFunc*)@from@_argmax,
(PyArray_DotFunc*)@from@_dot,
(PyArray_ScanFunc*)@from@_scan,
@from@_fromstr,
(PyArray_NonzeroFunc*)@from@_nonzero,
(PyArray_FillFunc*)@from@_fill,
(PyArray_FillWithScalarFunc*)@from@_fillwithscalar,
#if @sort@
{
quicksort_@suff@,
heapsort_@suff@,
mergesort_@suff@
},
{
aquicksort_@suff@,
aheapsort_@suff@,
amergesort_@suff@
},
#else
{
NULL, NULL, NULL
},
{
NULL, NULL, NULL
},
#endif
NULL,
(PyArray_ScalarKindFunc*)NULL,
NULL,
NULL,
(PyArray_FastClipFunc*)@from@_fastclip,
(PyArray_FastPutmaskFunc*)@from@_fastputmask,
(PyArray_FastTakeFunc*)@from@_fasttake,
(PyArray_ArgFunc*)@from@_argmin
};
这段代码(并不是唯一的设定PyArray_ArrFuncs
的代码,回想我们还要处理字符串等)为PyArray_SortFunc *sort[NPY_NSORTS]
和PyArray_ArgSortFunc *argsort[NPY_NSORTS]
分别初始化了一个数组,包括了所有的排序函数类型。而PyArray_ArrFuncs
最后又被包括进了PyArray_Descr
中,而PyArray_Descr
是我们亲爱的PyArrayObject
的主要成员。该结构是对一个NumPy数组的完整描述。
到这里我们好像已经结束了,其实并没有,因为PyArrayObject
是一个面向C的结构,并不暴露给上层的Python,上层的Python使用的是PyArray_Type
,定义在numpy/core/src/multiarray/arrayobject.c
中。该结构中包含methods
字段,包含了这一个NumPy数组对象可以使用的各种方法(methods)。这些方法定义在numpy/core/src/multiarray/methods.c
中,如排序:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55static PyObject *
array_sort(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
int axis=-1;
int val;
NPY_SORTKIND sortkind = NPY_QUICKSORT;
PyObject *order = NULL;
PyArray_Descr *saved = NULL;
PyArray_Descr *newd;
static char *kwlist[] = {"axis", "kind", "order", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|iO&O:sort", kwlist,
&axis,
PyArray_SortkindConverter, &sortkind,
&order)) {
return NULL;
}
if (order == Py_None) {
order = NULL;
}
if (order != NULL) {
PyObject *new_name;
PyObject *_numpy_internal;
saved = PyArray_DESCR(self);
if (!PyDataType_HASFIELDS(saved)) {
PyErr_SetString(PyExc_ValueError, "Cannot specify " \
"order when the array has no fields.");
return NULL;
}
_numpy_internal = PyImport_ImportModule("numpy.core._internal");
if (_numpy_internal == NULL) {
return NULL;
}
new_name = PyObject_CallMethod(_numpy_internal, "_newnames",
"OO", saved, order);
Py_DECREF(_numpy_internal);
if (new_name == NULL) {
return NULL;
}
newd = PyArray_DescrNew(saved);
Py_DECREF(newd->names);
newd->names = new_name;
((PyArrayObject_fields *)self)->descr = newd;
}
val = PyArray_Sort(self, axis, sortkind);
if (order != NULL) {
Py_XDECREF(PyArray_DESCR(self));
((PyArrayObject_fields *)self)->descr = saved;
}
if (val < 0) {
return NULL;
}
Py_RETURN_NONE;
}
这看起来就是一个很正统的Python函数了,甚至还有一个活灵活现的self
。我们在调用np.sort(a, kind='mergesort')
时,实际在调用a.sort(kind='mergesort')
,更进一步实际在调用上面的static PyObject * array_sort(PyArrayObject *self, PyObject *args, PyObject *kwds)
函数。这个函数实现了对参数的解析和默认参数的设置,以及其它一些必要的工作,最后调用了更加核心的PyArray_Sort
函数。这一函数定义在numpy/core/src/multiarray/item_selection.c
中:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44NPY_NO_EXPORT int
PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which)
{
PyArray_SortFunc *sort;
int n = PyArray_NDIM(op);
if (check_and_adjust_axis(&axis, n) < 0) {
return -1;
}
if (PyArray_FailUnlessWriteable(op, "sort array") < 0) {
return -1;
}
if (which < 0 || which >= NPY_NSORTS) {
PyErr_SetString(PyExc_ValueError, "not a valid sort kind");
return -1;
}
sort = PyArray_DESCR(op)->f->sort[which];
if (sort == NULL) {
if (PyArray_DESCR(op)->f->compare) {
switch (which) {
default:
case NPY_QUICKSORT:
sort = npy_quicksort;
break;
case NPY_HEAPSORT:
sort = npy_heapsort;
break;
case NPY_MERGESORT:
sort = npy_mergesort;
break;
}
}
else {
PyErr_SetString(PyExc_TypeError,
"type does not have compare function");
return -1;
}
}
return _new_sortlike(op, axis, sort, NULL, NULL, 0);
}
该函数在进行了参数检查后,进行了关键的sort = PyArray_DESCR(op)->f->sort[which];
,从上文所述的PyArray_Descr
找到了相应的排序函数。当无法找到时(当前的array没有定义特定的排序函数),默认将调用对Python对象进行排序的函数进行排序,前提是这个Python对象有比较函数。注意到,最后真正调用排序函数的又是另外一个包装函数_new_sortlike
,这是因为NumPy的partition
函数也需要相似的例程,其内容主要是对多坐标轴排序进行处理,不多做解释。简而言之,如果我们想要调用归并排序,调用过程为:1
2
3
4
5a.sort(kind='mergesort')
array_sort(methods.c)
PyArray_Sort(item_selection.c)
_new_sortlike(item_selection.c)
_@type@(mergesort.c.src) mergesort
我们在进行这些分析时,没有考虑进行argsort的情况,但其原理都是很类似的。
结语
这篇博文其实是以归并排序为例,对NumPy这个包底层的实现逻辑进行了一个分析。比如我们见识了一下在NumPy包中大量利用了的代码生成,见识了Python的C接口大概是什么样子,见识了一下简单的排序函数是如何被层层包装以达到最大程度的复用,也见识了NumPy在一些算法细节上的优秀处理。当然,这里的所谓“分析”还是比较粗糙和简略的,比如NumPy的分发过程很有意思(*.c.src
什么时候变成*.c
的?),我因为时间有限没有特别细看。路漫漫其修远兮,吾将上下而求索。