问题

>我有一个由arr定义的ndarray,它是一个n维立方体,每个维度的长度为m.

>我想通过沿着维度n = 0切片并将每个n-1-dim切片作为函数的输入来执行函数func.

这似乎适用于map(),但我找不到适合的numpy变体. np.vectorise似乎将n-1-tensor分成单个标量条目. apply_along_axis或apply_over_axes似乎也不合适.

我的问题是我需要传递任意函数作为输入,因此我没有看到einsum可行的解决方案.

>你知道使用np.asarray(map(func,arr))的最佳numpy替代方法吗?

我定义了一个示例数组,arr为4-dim立方体(或4-tensor):

m, n = 3, 4

arr = np.arange(m**n).reshape((m,)*n)

我定义了一个示例函数f,

def f(x):

"""makes it obvious how the np.ndarray is being passed into the function"""

try: # perform an op using x[0,0,0] which is expected to exist

i = x[0,0,0]

except:

print '\nno element x[0,0,0] in x: \n{}'.format(x)

return np.nan

return x-x+i

此函数的预期结果res将保持相同的形状,但将满足以下条件:

print all([(res[i] == i*m**(n-1)).all() for i in range(m)])

这适用于默认的map()函数,

res = np.asarray(map(f, a))

print all([(res[i] == i*m**(n-1)).all() for i in range(m)])

True

我希望np.vectorize以与map()相同的方式工作,但它在标量条目中起作用:

res = np.vectorize(f)(a)

no element x[0,0,0] in x:

0

...

最佳答案 鉴于arr是4d,你的fn适用于3d数组,

np.asarray(map(func, arr))

看起来非常合理.我会使用列表理解表单,但这是编程风格的问题

np.asarray([func(i) for i in arr])

因为我在arr的第一个维度迭代.实际上,它将arr视为3d数组的列表.然后它将结果列表重新组合成一个4d数组.

np.vectorize doc可以更明确地说明使用标量的函数.但是,是的,它将值传递为标量.请注意,np.vectorize没有提供传递迭代轴参数的规定.当你的函数从多个数组获取值时,它是最有用的

[func(a,b) for a,b in zip(arrA, arrB)]

它概括了拉链以便广播.但否则它是一个迭代的解决方案.它对你的功能的内容一无所知,因此无法加速其通话速度.

np.vectorize最终调用np.frompyfunc,这有点不那么通用,速度要快一些.但它也将标量传递给了func.

np.apply_along / over_ax(e / i)也迭代一个或多个轴.您可能会发现他们的代码具有指导性,但我同意他们不适用于此处.

映射方法的一个变体是分配结果数组和索引:

In [45]: res=np.zeros_like(arr,int)

In [46]: for i in range(arr.shape[0]):

...: res[i,...] = f(arr[i,...])

如果您需要在与第1个轴不同的轴上进行迭代,这可能会更容易.

你需要做自己的时间,看看哪个更快.

========================

使用就地修改在第一维上进行迭代的示例:

In [58]: arr.__array_interface__['data'] # data buffer address

Out[58]: (152720784, False)

In [59]: for i,a in enumerate(arr):

...: print(a.__array_interface__['data'])

...: a[0,0,:]=i

...:

(152720784, False) # address of the views (same buffer)

(152720892, False)

(152721000, False)

In [60]: arr

Out[60]:

array([[[[ 0, 0, 0],

[ 3, 4, 5],

[ 6, 7, 8]],

...

[[[ 1, 1, 1],

[30, 31, 32],

...

[[[ 2, 2, 2],

[57, 58, 59],

[60, 61, 62]],

...]]])

当我遍历一个数组时,我得到一个从公共数据缓冲区上的连续点开始的视图.如果我修改视图,如上所述甚至是[:] = …,我修改原始视图.我不需要写任何东西.但是不要使用a = ….,它打破了原始数组的链接.

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐