Мне нужно суммировать элементы в одномерном массиве numpy
(ниже: data
) на основе другого массива с информацией о членстве в классе (labels
). Я использую numba
в приведенном ниже коде, чтобы ускорить его. Однако, если я не указал явно int()
в строке ret[int(find(labels, g))] += y
, я получаю сообщение об ошибке:
TypeError: unsupported array index type ?int64
Есть ли лучший обходной путь, чем явное приведение?
import numpy as np
from numba import jit
labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)
@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct
@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret
find
гарантированно найдет что-то в структуре моей проблемы. Теперь я просто возвращаю фиктивное целое число для теоретического случая, когда нет попадания, и это работает! 04.09.2016