跳转至

NumPy分割数组

在 NumPy 中,可以使用多种方法将数组分割成多个子数组。这些方法包括 np.splitnp.array_splitnp.hsplitnp.vsplitnp.dsplit。以下是这些方法的详细介绍及示例:

np.split

np.split 函数用于将数组沿指定轴平均分割为多个子数组。

import numpy as np

# 一维数组分割
arr = np.array([1, 2, 3, 4, 5, 6])
result = np.split(arr, 3)
print(result)
# 输出: [array([1, 2]), array([3, 4]), array([5, 6])]

# 二维数组分割
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
result = np.split(arr, 2, axis=0)
print(result)
# 输出:
# [array([[1, 2, 3],
#        [4, 5, 6]]), 
#  array([[ 7,  8,  9],
#        [10, 11, 12]])]

result = np.split(arr, 3, axis=1)
print(result)
# 输出:
# [array([[ 1],
#        [ 4],
#        [ 7],
#        [10]]), 
#  array([[ 2],
#        [ 5],
#        [ 8],
#        [11]]), 
#  array([[ 3],
#        [ 6],
#        [ 9],
#        [12]])]

np.array_split

np.array_split 函数类似于 np.split,但允许不等分割,如果不能平均分割数组,则最后一个子数组会包含剩余的元素。

# 一维数组不等分割
arr = np.array([1, 2, 3, 4, 5, 6, 7])
result = np.array_split(arr, 3)
print(result)
# 输出: [array([1, 2, 3]), array([4, 5]), array([6, 7])]

# 二维数组不等分割
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
result = np.array_split(arr, 3, axis=0)
print(result)
# 输出:
# [array([[1, 2, 3],
#        [4, 5, 6]]), 
#  array([[7, 8, 9]]), 
#  array([[10, 11, 12]])]

np.hsplit

np.hsplit 函数用于将数组水平分割(按列分割)。等价于 np.split 函数中的 axis=1

arr = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
result = np.hsplit(arr, 2)
print(result)
# 输出:
# [array([[ 1,  2],
#        [ 5,  6],
#        [ 9, 10]]), 
#  array([[ 3,  4],
#        [ 7,  8],
#        [11, 12]])]

result = np.hsplit(arr, [1, 3])
print(result)
# 输出:
# [array([[ 1],
#        [ 5],
#        [ 9]]), 
#  array([[ 2,  3],
#        [ 6,  7],
#        [10, 11]]), 
#  array([[ 4],
#        [ 8],
#        [12]])]

np.vsplit

np.vsplit 函数用于将数组垂直分割(按行分割)。等价于 np.split 函数中的 axis=0

arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
result = np.vsplit(arr, 2)
print(result)
# 输出:
# [array([[1, 2, 3],
#        [4, 5, 6]]), 
#  array([[ 7,  8,  9],
#        [10, 11, 12]])]

result = np.vsplit(arr, [1, 3])
print(result)
# 输出:
# [array([[1, 2, 3]]), 
#  array([[4, 5, 6],
#        [7, 8, 9]]), 
#  array([[10, 11, 12]])]

np.dsplit

np.dsplit 函数用于将数组沿深度方向分割(第三轴)。

arr = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
result = np.dsplit(arr, 3)
print(result)
# 输出:
# [array([[[ 1],
#         [ 4]],
#        [[ 7],
#         [10]]]), 
#  array([[[ 2],
#         [ 5]],
#        [[ 8],
#         [11]]]), 
#  array([[[ 3],
#         [ 6]],
#        [[ 9],
#         [12]]])]

通过这些方法,您可以灵活地分割 NumPy 数组以满足不同的数据处理需求。

评论