警惕NumPy切片视图(Slice View)中的“内存泄漏”陷阱

NumPy的视图(View)机制可以大大加快对数组进行切片和reshape的速度,同时节省内存。但View机制存在一个极难发现的陷阱,会在许多常见的应用场景下引起内存泄漏。

视图(View)机制

在Python中,对列表或者字符串进行切片必须进行拷贝:

1
2
3
4
5
6
7
8
9
>>> a = list(range(10))
>>> b = a[:5]
>>> b
[0, 1, 2, 3, 4]
>>> b[0] = 999
>>> b
[999, 1, 2, 3, 4]
>>> a
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

这意味着切片操作有着不菲的时间开销。在性能第一的科学计算库NumPy中,为了使大型数组的切片可以在常数时间内完成,也为了方便对大数组局部进行修改,引入了视图机制,即切片时不进行拷贝,返回一个数据指针仍然指向原数组,只是metadata发生了变化的视图:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> a = np.arange(10)
>>> a.flags.owndata
True
>>> b = a[:5]
>>> b
array([0, 1, 2, 3, 4])
>>> b.flags.owndata
False
>>> b[0] = 999
>>> b
array([999, 1, 2, 3, 4])
>>> a
array([999, 1, 2, 3, 4, 5, 6, 7, 8, 9])

相当多的数值计算库实际上都使用了类似的做法。当然,根据官方文档,视图机制还有许多别的作用,不过本文关注的内存泄露问题只和切片有关。

切片视图导致的内存泄漏

以下这段代码代表了一个很常见的逻辑:通过某种操作构造一个大数组,然后从中取出并保留特定的一小部分。这一逻辑会引起极难发现的内存泄漏。

1
2
3
4
5
def foo():
a = np.random.rand(int(2e8))
b = a[:100]
return b
b = foo()

这段代码的最终效果是通过foo()将一个长度为100的数组赋值给b,看起来并不像是典型的内存泄漏代码。感兴趣的读者不妨自己在交互环境中试一试这段代码,然后查看内存占用。在本例中我特意创建了一个非常大的数组,含有\(2 \times 10^8\)个元素,应该会有超过1G的内存占用。依据直觉和常识,在创建a时有超过1G的内存占用是完全合理的,但是在执行完毕foo()后,a会被垃圾回收,大量内存会被释放掉,最终留下一个内存占用几乎可以忽略的b

1
2
3
4
5
>>> from sys import getsizeof
>>> getsizeof(b)
96
>>> b.nbytes
800

但实际操作发现在foo()执行结束后,进程的内存占用仍然居高不下,只有执行del b之后才会看到显著的内存占用下降。
根据观察到的现象,结合对视图机制的理解,我们不难推测,b中包含一个指向a的指针,导致a的引用计数在执行foo()后并没有归零进而被垃圾回收,而是一直保存在b中。事实上,我们只需要简单地访问b.base,就可以获得直觉上已经被垃圾回收的a对象:

1
2
>>> b.base.shape
(200000000, )

这一base属性(attribute)在NumPy内部实现上也非常简单直接,就是一个指向原数组的指针。对于一个ndarray的数据结构来说,必要的成员变量包括一个指向数据的指针,指向维度数组的指针以及维度数组的大小,再加上strides步长。NumPy源码在定义ndarray的数据结构时,首先声明了上述必须的成员变量,名为base的指针紧跟其后。代码中还包括大量对base作用的注释:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
typedef struct tagPyArrayObject_fields {
PyObject_HEAD
/* Pointer to the raw data buffer */
char *data;
/* The number of dimensions, also called 'ndim' */
int nd;
/* The size in each dimension, also called 'shape' */
npy_intp *dimensions;
/*
* Number of bytes to jump to get to the
* next element in each dimension
*/
npy_intp *strides;
/*
* This object is decref'd upon
* deletion of array. Except in the
* case of WRITEBACKIFCOPY which has
* special handling.
*
* For views it points to the original
* array, collapsed so no chains of
* views occur.
*
* For creation from buffer object it
* points to an object that should be
* decref'd on deletion
*
* For WRITEBACKIFCOPY flag this is an
* array to-be-updated upon calling
* PyArray_ResolveWritebackIfCopy
*/
PyObject *base;
/* Pointer to type structure */
PyArray_Descr *descr;
/* Flags describing array -- see below */
int flags;
/* For weak references */
PyObject *weakreflist;
} PyArrayObject_fields;

对于自己拥有数据的非视图数组而言,base指针被设为NULL,反映到Python中(通过descriptor)就是None:

1
2
>>> print(b.base.base)
None

至此,前面提到的内存泄漏的原因也就十分清晰了。尽管foo()返回的数组大小只有800个bytes(根据b.nbytes),但其实它包括了一个指向更大数组的指针base,而这个base的可能会非常大,实际上造成了内存的泄漏。

解决方案

对于用户来说,最简单(但并不容易!)的解决方案是在数组切片时视情况进行copy,使之数据成为自有的,切断和base的联系:

1
2
3
4
5
def foo():
a = np.random.rand(int(2e8))
b = a[:100].copy()
return b
b = foo()

对于NumPy来说,在一个如此常见的使用场景下存在内存泄漏的隐患不是合理的设计。一个最直接的“改进”是在NumPy的ndarray数据结构中加入该数组被多少其它数组作为base的字段,当这一数值和引用计数相等时令视图数组进行拷贝。这一“改进”会带来包括性能下降在内的一系列问题,甚至在存在多个视图的情况下可能同时导致CPU和内存效率下降。我目前很难想到在不影响效率的情况下解决这一问题的方案。

后记杂谈

我是在一个排查一个大型应用的异常内存占用时发现这一“内存泄漏”问题的。在尝试了许多memory profiler无果之后我自己开发了一个工具RememberMe来检查Python中对象的内存占用。RememberMe简单来说就是sys.getsizeof\(\times\)gc.get_referents,对Python对象内存占用的估计还是比较准确的。我利用这一工具不断缩小排查范围,最终得以确定是一组NumPy数组的内存占用远超预期,并追溯到引起这一异常行为的切片操作。Python最好的性能profiler应该是py-spy,内存profiler上实在没什么特别好的工具。
RememberMe翻译成中文可作“勿忘我”。不少人可能因为CoCo这部电影而对Remember Me有所印象,但我起这个名字是因为一部赛博朋克电子游戏

又及:经本人PR,在NumPy文档中已经添加了一些提醒。