Python - Indexing a tensor with a binary matrix in NumPy


I have a tensor a with a.shape = (32, 19, 2) and a binary matrix b with b.shape = (32, 19). Is there a one-line operation that produces a matrix c with c.shape = (32, 19) and c[i, j] = a[i, j, b[i, j]]?

Essentially, I want to use b as an indexing matrix: if b[i, j] = 1, take a[i, j, 1] to form c[i, j], otherwise take a[i, j, 0].
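A plain double loop spells out the required semantics and is handy for checking the vectorized answers. This is a minimal sketch: the function name select_by_index_loop and the random test data are illustrative assumptions, not part of the question.

```python
import numpy as np

def select_by_index_loop(a, b):
    """Hypothetical reference implementation: c[i, j] = a[i, j, b[i, j]] (slow)."""
    c = np.empty(b.shape, dtype=a.dtype)
    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            c[i, j] = a[i, j, b[i, j]]
    return c

# Random data with the shapes from the question (values are arbitrary).
rng = np.random.default_rng(0)
a = rng.random((32, 19, 2))
b = rng.integers(0, 2, size=(32, 19))  # binary matrix of 0s and 1s
c = select_by_index_loop(a, b)
print(c.shape)  # (32, 19)
```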

np.where to the rescue. It's the same principle as mtrw's answer:

    In [344]: a = np.arange(4*3*2).reshape(4,3,2)

    In [345]: b = np.zeros((4,3), dtype=int)

    In [346]: b[[0,1,1,2,3],[0,0,1,2,2]] = 1

    In [347]: b
    Out[347]:
    array([[1, 0, 0],
           [1, 1, 0],
           [0, 0, 1],
           [0, 0, 1]])

    In [348]: np.where(b, a[:,:,1], a[:,:,0])
    Out[348]:
    array([[ 1,  2,  4],
           [ 7,  9, 10],
           [12, 14, 17],
           [18, 20, 23]])
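The same np.where call applied at the question's full size, checked against the definition c[i, j] = a[i, j, b[i, j]]. The random data is an assumption for illustration; note that np.where only covers the two-way case, since it treats b as a boolean mask.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.random((32, 19, 2))
b = rng.integers(0, 2, size=(32, 19))

# Where b is nonzero take the a[:, :, 1] slice, otherwise a[:, :, 0].
c = np.where(b, a[:, :, 1], a[:, :, 0])

# Check against the element-by-element definition.
expected = np.array([[a[i, j, b[i, j]] for j in range(19)] for i in range(32)])
print(np.array_equal(c, expected))  # True
```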

np.choose can be used if the last dimension is larger than 2 (but smaller than 32). choose operates on a list, or along the first dimension of an array, hence the rollaxis.

    In [360]: np.choose(b, np.rollaxis(a,2))
    Out[360]:
    array([[ 1,  2,  4],
           [ 7,  9, 10],
           [12, 14, 17],
           [18, 20, 23]])
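A small sketch of the case np.choose is meant for, assuming a last dimension of 3 and an index matrix with values 0 to 2 (both made up for illustration). np.moveaxis(a, 2, 0) is used here as an equivalent of np.rollaxis(a, 2).

```python
import numpy as np

a = np.arange(4 * 3 * 3).reshape(4, 3, 3)   # last dimension has 3 entries
b = np.array([[0, 1, 2],
              [2, 1, 0],
              [1, 1, 1],
              [0, 2, 0]])

# np.moveaxis(a, 2, 0) has shape (3, 4, 3); np.choose treats its first axis
# as the list of choices and picks choice b[i, j] at each position (i, j).
c = np.choose(b, np.moveaxis(a, 2, 0))
print(c.tolist())  # [[0, 4, 8], [11, 13, 15], [19, 22, 25], [27, 32, 33]]
```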

b can also be used directly as an index. The trick is to specify the other dimensions in a way that broadcasts to the same shape as b.

    In [373]: a[np.arange(a.shape[0])[:,None], np.arange(a.shape[1])[None,:], b]
    Out[373]:
    array([[ 1,  2,  4],
           [ 7,  9, 10],
           [12, 14, 17],
           [18, 20, 23]])
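On NumPy 1.15 and later, np.take_along_axis packages the same broadcasting trick. This is a swapped-in alternative, not part of the original answer; the sketch below checks that it agrees with the explicit fancy index on the example arrays.

```python
import numpy as np

a = np.arange(4 * 3 * 2).reshape(4, 3, 2)
b = np.array([[1, 0, 0],
              [1, 1, 0],
              [0, 0, 1],
              [0, 0, 1]])

# Explicit broadcasting: index shapes (4, 1), (1, 3) and (4, 3) broadcast to (4, 3).
c_fancy = a[np.arange(a.shape[0])[:, None], np.arange(a.shape[1])[None, :], b]

# take_along_axis needs an index array with the same ndim as `a`, so b gets
# a trailing length-1 axis, which is dropped again from the result.
c_take = np.take_along_axis(a, b[:, :, None], axis=2)[:, :, 0]

print(np.array_equal(c_fancy, c_take))  # True
```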

This last approach can be modified to work when b does not match the first 2 dimensions of a.
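One illustrative case, assumed here rather than taken from the answer: a hypothetical b_rows holds a single choice per row of a (shape (4,) instead of (4, 3)). Giving it a trailing axis lets it broadcast across the columns, so the same indexing pattern still works.

```python
import numpy as np

a = np.arange(4 * 3 * 2).reshape(4, 3, 2)
b_rows = np.array([1, 0, 0, 1])          # hypothetical: one choice per row, shape (4,)

rows = np.arange(a.shape[0])[:, None]    # shape (4, 1)
cols = np.arange(a.shape[1])[None, :]    # shape (1, 3)
# b_rows[:, None] has shape (4, 1) and broadcasts with rows and cols to (4, 3).
c = a[rows, cols, b_rows[:, None]]
print(c.tolist())  # [[1, 3, 5], [6, 8, 10], [12, 14, 16], [19, 21, 23]]
```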

np.ix_ may simplify the indexing:

    i, j = np.ix_(np.arange(4), np.arange(3))
    a[i, j, b]
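A self-contained version of the np.ix_ variant on the example arrays. It shows that np.ix_ builds the same (4, 1) and (1, 3) index arrays that were written by hand with [:, None] and [None, :].

```python
import numpy as np

a = np.arange(4 * 3 * 2).reshape(4, 3, 2)
b = np.array([[1, 0, 0],
              [1, 1, 0],
              [0, 0, 1],
              [0, 0, 1]])

# np.ix_ returns open-mesh index arrays that broadcast against each other and b.
i, j = np.ix_(np.arange(a.shape[0]), np.arange(a.shape[1]))
c = a[i, j, b]
print(i.shape, j.shape)  # (4, 1) (1, 3)
print(c.tolist())        # [[1, 2, 4], [7, 9, 10], [12, 14, 17], [18, 20, 23]]
```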
