这是numpy数组比较的问题

情况一:返回的是数组

import numpy as np
a = np.array([1, 2, 3])
b = np.array([1, 5, 6])
if a == b:
    pass

因为a==b的结果是[True False False]
解决方案是:

  • .any():只要有一个位置的元素TrueTrue
  • .all():每个位置的元素都TrueTrue
print((a == b).any())   # True
print((a == b).all())   # False

情况二:list(numpy数组)

import numpy as np

# 纯numpy没问题
a = np.array([1, 2])
b = [1, 2]
print(a == b)   # [ True  True ]

# 纯numpy没问题
c = np.array([[1, 2], [3, 4]])
d = [[1, 2], [3, 4]]
print(c == d)   # [[ True  True], [ True  True]]

# list(numpy)就不行
n1 = np.array([1, 2])
n2 = np.array([3, 4])
n_list_1 = [n1]
n_list_2 = [n1, n2]
print(n_list_1 == b)    # ValueError
print(n_list_2 == d)    # ValueError

解法:将list(numpy数组)转化为纯Numpy

print(np.array(n_list_1) == b)    # [ True  True ]
print(np.array(n_list_2) == d)    # [[ True  True], [ True  True]]
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐