@@ -309,27 +309,42 @@ def concat(arrays, axis=0, dtype=None):
309309 return result
310310
311311
312- class LArrayIterator (object ):
313- __slots__ = ('nextfunc' , 'axes' )
312+ if PY2 :
313+ class LArrayIterator (object ):
314+ __slots__ = ('next' ,)
314315
315- def __init__ (self , array ):
316- data_iter = iter (array .data )
317- self .nextfunc = data_iter .next if PY2 else data_iter .__next__
318- self .axes = array .axes [1 :]
316+ def __init__ (self , array ):
317+ data_iter = iter (array .data )
318+ next_data_func = data_iter .next
319+ res_axes = array .axes [1 :]
320+ # this case should not happen (handled by the fastpath in LArray.__iter__)
321+ assert len (res_axes ) > 0
319322
320- def __iter__ ( self ):
321- return self
323+ def next_func ( ):
324+ return LArray ( next_data_func (), res_axes )
322325
323- def __next__ (self ):
324- data = self .nextfunc ()
325- axes = self .axes
326- if len (axes ):
327- return LArray (data , axes )
328- else :
329- return data
326+ self .next = next_func
330327
331- # Python 2
332- next = __next__
328+ def __iter__ (self ):
329+ return self
330+ else :
331+ class LArrayIterator (object ):
332+ __slots__ = ('__next__' ,)
333+
334+ def __init__ (self , array ):
335+ data_iter = iter (array .data )
336+ next_data_func = data_iter .__next__
337+ res_axes = array .axes [1 :]
338+ # this case should not happen (handled by the fastpath in LArray.__iter__)
339+ assert len (res_axes ) > 0
340+
341+ def next_func ():
342+ return LArray (next_data_func (), res_axes )
343+
344+ self .__next__ = next_func
345+
346+ def __iter__ (self ):
347+ return self
333348
334349
335350# TODO: rename to LArrayIndexIndexer or something like that
@@ -355,14 +370,41 @@ def _translate_key(self, key):
355370 for axis_key , axis in zip (key , self .array .axes ))
356371
357372 def __getitem__ (self , key ):
358- return self .array [self ._translate_key (key )]
373+ ndim = self .array .ndim
374+ full_scalar_key = (
375+ (isinstance (key , (int , np .integer )) and ndim == 1 ) or
376+ (isinstance (key , tuple ) and len (key ) == ndim and all (isinstance (k , (int , np .integer )) for k in key ))
377+ )
378+ # fast path when the result is a scalar
379+ if full_scalar_key :
380+ return self .array .data [key ]
381+ else :
382+ return self .array [self ._translate_key (key )]
359383
360384 def __setitem__ (self , key , value ):
361- self .array [self ._translate_key (key )] = value
385+ array = self .array
386+ ndim = array .ndim
387+ full_scalar_key = (
388+ (isinstance (key , (int , np .integer )) and ndim == 1 ) or
389+ (isinstance (key , tuple ) and len (key ) == ndim and all (isinstance (k , (int , np .integer )) for k in key ))
390+ )
391+ # fast path when setting a single cell
392+ if full_scalar_key :
393+ array .data [key ] = value
394+ else :
395+ array [self ._translate_key (key )] = value
362396
363397 def __len__ (self ):
364398 return len (self .array )
365399
400+ def __iter__ (self ):
401+ array = self .array
402+ # fast path for 1D arrays (where we return scalars)
403+ if array .ndim <= 1 :
404+ return iter (array .data )
405+ else :
406+ return LArrayIterator (array )
407+
366408
367409class LArrayPointsIndexer (object ):
368410 __slots__ = ('array' ,)
@@ -2696,6 +2738,7 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs):
26962738 arr = np .asarray (arr )
26972739 op (arr , axis = axis_idx , out = out , ** kwargs )
26982740 del arr
2741+
26992742 if killaxis :
27002743 assert group_idx [axis_idx ] == 0
27012744 res_data = res_data [idx ]
0 commit comments