我们先来看下np.split的实现方法:
@array_function_dispatch(_split_dispatcher)
def split(ary, indices_or_sections, axis=0):try:len(indices_or_sections)except TypeError:sections = indices_or_sectionsN = ary.shape[axis]if N % sections:raise ValueError('array split does not result in an equal division') from Nonereturn array_split(ary, indices_or_sections, axis)
当然有兴趣的可以继续看array_split的具体操作方法。
从split的定义可以看到,参数是(数组,数或数组,维度)返回值是列表,里面的每组元素是数组。看示例:
参数是整数
import numpy as npa=np.arange(10)
np.split(a,5)
#[array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7]), array([8, 9])]
将0-9,均分成5份。如果不能被整除,比如是4,将出现错误:
ValueError: array split does not result in an equal division
参数是列表,按照里面每个值来分段
a=np.arange(10)
np.split(a,[4,8])
#[array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([8, 9])]
变形成二维数组来切分
a=np.arange(40).reshape(8,5)
np.split(a,4)
'''
[array([[0, 1, 2, 3, 4],[5, 6, 7, 8, 9]]),array([[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]]),array([[20, 21, 22, 23, 24],[25, 26, 27, 28, 29]]),array([[30, 31, 32, 33, 34],[35, 36, 37, 38, 39]])]
'''
axis=1的结果
a=np.arange(40).reshape(5,8)
np.split(a,4,axis=1)
'''
[array([[ 0, 1],[ 8, 9],[16, 17],[24, 25],[32, 33]]),array([[ 2, 3],[10, 11],[18, 19],[26, 27],[34, 35]]),array([[ 4, 5],[12, 13],[20, 21],[28, 29],[36, 37]]),array([[ 6, 7],[14, 15],[22, 23],[30, 31],[38, 39]])]
'''
接下来的nd.split的用法跟np.split虽然用法很像,还是存在一些区别需要注意。
还是贴下nd.split的方法:
def split(data=None, num_outputs=_Null, axis=_Null, squeeze_axis=_Null, out=None, name=None, **kwargs):return (0,)
跟np.split的区别就是必须指定axis,不然会报错。
from mxnet import nda=nd.arange(40).reshape(8,5)
nd.split(a,4,axis=0)'''
[[[0. 1. 2. 3. 4.][5. 6. 7. 8. 9.]]<NDArray 2x5 @cpu(0)>,[[10. 11. 12. 13. 14.][15. 16. 17. 18. 19.]]<NDArray 2x5 @cpu(0)>,[[20. 21. 22. 23. 24.][25. 26. 27. 28. 29.]]<NDArray 2x5 @cpu(0)>,[[30. 31. 32. 33. 34.][35. 36. 37. 38. 39.]]<NDArray 2x5 @cpu(0)>]
'''
三维的例子亦如是,如:切分第二维切4份
a=nd.arange(40).reshape(2,4,5)
nd.split(a,4,axis=1)
'''
[[[[ 0. 1. 2. 3. 4.]][[20. 21. 22. 23. 24.]]]<NDArray 2x1x5 @cpu(0)>,[[[ 5. 6. 7. 8. 9.]][[25. 26. 27. 28. 29.]]]<NDArray 2x1x5 @cpu(0)>,[[[10. 11. 12. 13. 14.]][[30. 31. 32. 33. 34.]]]<NDArray 2x1x5 @cpu(0)>,[[[15. 16. 17. 18. 19.]][[35. 36. 37. 38. 39.]]]<NDArray 2x1x5 @cpu(0)>]
'''
每个元素的形状是nd.split(a,4,axis=1)[1].shape #(2,1,5)
除了参数名称不一样,个数也不一样,比如squeeze_axis这个新增的参数,可以减掉一维。
a=nd.arange(40).reshape(2,4,5)
nd.split(a,2,axis=0,squeeze_axis=1)
'''
[[[ 0. 1. 2. 3. 4.][ 5. 6. 7. 8. 9.][10. 11. 12. 13. 14.][15. 16. 17. 18. 19.]]<NDArray 4x5 @cpu(0)>,[[20. 21. 22. 23. 24.][25. 26. 27. 28. 29.][30. 31. 32. 33. 34.][35. 36. 37. 38. 39.]]<NDArray 4x5 @cpu(0)>]
'''
看出有什么不同了吗?少了一维,本来里面每个元素是 <NDArray 1x4x5 @cpu(0)>,现在变为 <NDArray 4x5 @cpu(0)>
再看一例:
a=nd.arange(40).reshape(2,4,5)
nd.split(a,4,axis=1,squeeze_axis=1)
'''
[[[ 0. 1. 2. 3. 4.][20. 21. 22. 23. 24.]]<NDArray 2x5 @cpu(0)>,[[ 5. 6. 7. 8. 9.][25. 26. 27. 28. 29.]]<NDArray 2x5 @cpu(0)>,[[10. 11. 12. 13. 14.][30. 31. 32. 33. 34.]]<NDArray 2x5 @cpu(0)>,[[15. 16. 17. 18. 19.][35. 36. 37. 38. 39.]]<NDArray 2x5 @cpu(0)>]
'''
如果没有squeeze_axis=1这个参数,里面的元素形状是<NDArray 2x1x5 @cpu(0)>, 现在变为<NDArray 2x5 @cpu(0)>
所以这个其实就是将所在切分的维,有且仅有1,那么就减掉这个维度。这个其实是有意义的,毕竟属于没数据的占着空的维度,可以去掉。
本文发布于:2024-02-02 07:10:26,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170682902742187.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |