# HG changeset patch
# User da-woods <dw-git@d-woods.co.uk>
# Date 1593509001 -3600
#      Tue Jun 30 10:23:21 2020 +0100
# Node ID 27337172d93448424f9c1b6a43e226deac76b4ba
# Parent  4a4a89917943e5fc6d55708db88d9175d73b94fe
Implement generic optimized loop iterator with indexing and type inference for memoryviews (GH-3617)

* Adds bytearray iteration since that was not previously optimised (because it allows changing length during iteration).
* Always set `entry.init` for memoryviewslice.

diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -3564,6 +3564,8 @@
                                bytearray_type, list_type, tuple_type):
                 # slicing these returns the same type
                 return base_type
+            elif base_type.is_memoryviewslice:
+                return base_type
             else:
                 # TODO: Handle buffers (hopefully without too much redundancy).
                 return py_object_type
@@ -3606,6 +3608,23 @@
                         index += base_type.size
                     if 0 <= index < base_type.size:
                         return base_type.components[index]
+            elif base_type.is_memoryviewslice:
+                if base_type.ndim == 0:
+                    pass  # probably an error, but definitely don't know what to do - return pyobject for now
+                if base_type.ndim == 1:
+                    return base_type.dtype
+                else:
+                    return PyrexTypes.MemoryViewSliceType(base_type.dtype, base_type.axes[1:])
+
+        if self.index.is_sequence_constructor and base_type.is_memoryviewslice:
+            inferred_type = base_type
+            for a in self.index.args:
+                if not inferred_type.is_memoryviewslice:
+                    break  # something's gone wrong
+                inferred_type = IndexNode(self.pos, base=ExprNode(self.base.pos, type=inferred_type),
+                                          index=a).infer_type(env)
+            else:
+                return inferred_type
 
         if base_type.is_cpp_class:
             class FakeOperand:
@@ -13466,6 +13485,9 @@
         # The arg is always already analysed
         return self
 
+    def may_be_none(self):
+        return self.arg.may_be_none()
+
     def coerce_to_boolean(self, env):
         self.arg = self.arg.coerce_to_boolean(env)
         if self.arg.is_simple():
diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py
--- a/Cython/Compiler/Optimize.py
+++ b/Cython/Compiler/Optimize.py
@@ -228,6 +228,12 @@
             return self._transform_bytes_iteration(node, iterable, reversed=reversed)
         if iterable.type is Builtin.unicode_type:
             return self._transform_unicode_iteration(node, iterable, reversed=reversed)
+        # in principle _transform_indexable_iteration would work on most of the above, and
+        # also tuple and list. However, it probably isn't quite as optimized
+        if iterable.type is Builtin.bytearray_type:
+            return self._transform_indexable_iteration(node, iterable, is_mutable=True, reversed=reversed)
+        if isinstance(iterable, ExprNodes.CoerceToPyTypeNode) and iterable.arg.type.is_memoryviewslice:
+            return self._transform_indexable_iteration(node, iterable.arg, is_mutable=False, reversed=reversed)
 
         # the rest is based on function calls
         if not isinstance(iterable, ExprNodes.SimpleCallNode):
@@ -333,6 +339,92 @@
             PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
             ])
 
+    def _transform_indexable_iteration(self, node, slice_node, is_mutable, reversed=False):
+        """In principle can handle any iterable that Cython has a len() for and knows how to index"""
+        unpack_temp_node = UtilNodes.LetRefNode(
+            slice_node.as_none_safe_node("'NoneType' is not iterable"),
+            may_hold_none=False, is_temp=True
+            )
+
+        start_node = ExprNodes.IntNode(
+            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
+        def make_length_call():
+            # helper function since we need to create this node for a couple of places
+            builtin_len = ExprNodes.NameNode(node.pos, name="len",
+                                             entry=Builtin.builtin_scope.lookup("len"))
+            return ExprNodes.SimpleCallNode(node.pos,
+                                    function=builtin_len,
+                                    args=[unpack_temp_node]
+                                    )
+        length_temp = UtilNodes.LetRefNode(make_length_call(), type=PyrexTypes.c_py_ssize_t_type, is_temp=True)
+        end_node = length_temp
+
+        if reversed:
+            relation1, relation2 = '>', '>='
+            start_node, end_node = end_node, start_node
+        else:
+            relation1, relation2 = '<=', '<'
+
+        counter_ref = UtilNodes.LetRefNode(pos=node.pos, type=PyrexTypes.c_py_ssize_t_type)
+
+        target_value = ExprNodes.IndexNode(slice_node.pos, base=unpack_temp_node,
+                                           index=counter_ref)
+
+        target_assign = Nodes.SingleAssignmentNode(
+            pos = node.target.pos,
+            lhs = node.target,
+            rhs = target_value)
+
+        # analyse with boundscheck and wraparound
+        # off (because we're confident we know the size)
+        env = self.current_env()
+        new_directives = Options.copy_inherited_directives(env.directives, boundscheck=False, wraparound=False)
+        target_assign = Nodes.CompilerDirectivesNode(
+            target_assign.pos,
+            directives=new_directives,
+            body=target_assign,
+        )
+
+        body = Nodes.StatListNode(
+            node.pos,
+            stats = [target_assign])  # exclude node.body for now to not reanalyse it
+        if is_mutable:
+            # We need to be slightly careful here that we are actually modifying the loop
+            # bounds and not a temp copy of it. Setting is_temp=True on length_temp seems
+            # to ensure this.
+            # If this starts to fail then we could insert an "if out_of_bounds: break" instead
+            loop_length_reassign = Nodes.SingleAssignmentNode(node.pos,
+                                                        lhs = length_temp,
+                                                        rhs = make_length_call())
+            body.stats.append(loop_length_reassign)
+
+        loop_node = Nodes.ForFromStatNode(
+            node.pos,
+            bound1=start_node, relation1=relation1,
+            target=counter_ref,
+            relation2=relation2, bound2=end_node,
+            step=None, body=body,
+            else_clause=node.else_clause,
+            from_range=True)
+
+        ret = UtilNodes.LetNode(
+                    unpack_temp_node,
+                    UtilNodes.LetNode(
+                        length_temp,
+                        # TempResultFromStatNode provides the framework where the "counter_ref"
+                        # temp is set up and can be assigned to. However, we don't need the
+                        # result it returns so wrap it in an ExprStatNode.
+                        Nodes.ExprStatNode(node.pos,
+                            expr=UtilNodes.TempResultFromStatNode(
+                                    counter_ref,
+                                    loop_node
+                            )
+                        )
+                    )
+                ).analyse_expressions(env)
+        body.stats.insert(1, node.body)
+        return ret
+
     def _transform_bytes_iteration(self, node, slice_node, reversed=False):
         target_type = node.target.type
         if not target_type.is_int and target_type is not Builtin.bytes_type:
diff --git a/Cython/Compiler/Options.py b/Cython/Compiler/Options.py
--- a/Cython/Compiler/Options.py
+++ b/Cython/Compiler/Options.py
@@ -166,6 +166,16 @@
                 _directive_defaults[old_option.directive_name] = value
     return _directive_defaults
 
+def copy_inherited_directives(outer_directives, **new_directives):
+    # A few directives are not copied downwards and this function removes them.
+    # For example, test_assert_path_exists and test_fail_if_path_exists should not be inherited
+    #  otherwise they can produce very misleading test failures
+    new_directives_out = dict(outer_directives)
+    for name in ('test_assert_path_exists', 'test_fail_if_path_exists'):
+        new_directives_out.pop(name, None)
+    new_directives_out.update(new_directives)
+    return new_directives_out
+
 # Declare compiler directives
 _directive_defaults = {
     'binding': True,  # was False before 3.0
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
--- a/Cython/Compiler/ParseTreeTransforms.py
+++ b/Cython/Compiler/ParseTreeTransforms.py
@@ -992,12 +992,7 @@
             return self.visit_Node(node)
 
         old_directives = self.directives
-        new_directives = dict(old_directives)
-        # test_assert_path_exists and test_fail_if_path_exists should not be inherited
-        # otherwise they can produce very misleading test failures
-        new_directives.pop('test_assert_path_exists', None)
-        new_directives.pop('test_fail_if_path_exists', None)
-        new_directives.update(directives)
+        new_directives = Options.copy_inherited_directives(old_directives, **directives)
 
         if new_directives == old_directives:
             return self.visit_Node(node)
diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py
--- a/Cython/Compiler/PyrexTypes.py
+++ b/Cython/Compiler/PyrexTypes.py
@@ -672,6 +672,10 @@
         else:
             return False
 
+    def __ne__(self, other):
+        # TODO drop when Python2 is dropped
+        return not (self == other)
+
     def same_as_resolved_type(self, other_type):
         return ((other_type.is_memoryviewslice and
             #self.writable_needed == other_type.writable_needed and  # FIXME: should be only uni-directional
@@ -2516,6 +2520,7 @@
         if self.is_string:
             assert isinstance(value, str)
             return '"%s"' % StringEncoding.escape_byte_string(value)
+        return str(value)
 
 
 class CArrayType(CPointerBaseType):
diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py
--- a/Cython/Compiler/TypeInference.py
+++ b/Cython/Compiler/TypeInference.py
@@ -140,7 +140,6 @@
                                                      '+',
                                                      sequence.args[0],
                                                      sequence.args[2]))
-
         if not is_special:
             # A for-loop basically translates to subsequent calls to
             # __getitem__(), so using an IndexNode here allows us to
@@ -360,9 +359,11 @@
     applies to nested scopes in top-down order.
     """
     def set_entry_type(self, entry, entry_type):
-        entry.type = entry_type
         for e in entry.all_entries():
             e.type = entry_type
+            if e.type.is_memoryviewslice:
+                # memoryview slices crash if they don't get initialized
+                e.init = e.type.default_value
 
     def infer_types(self, scope):
         enabled = scope.directives['infer_types']
@@ -577,6 +578,8 @@
         # used, won't arise in pure Python, and there shouldn't be side
         # effects, so I'm declaring this safe.
         return result_type
+    elif result_type.is_memoryviewslice:
+        return result_type
     # TODO: double complex should be OK as well, but we need
     # to make sure everything is supported.
     elif (result_type.is_int or result_type.is_enum) and not might_overflow:
diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py
--- a/Cython/Compiler/UtilNodes.py
+++ b/Cython/Compiler/UtilNodes.py
@@ -360,3 +360,6 @@
     def generate_result_code(self, code):
         self.result_ref.result_code = self.result()
         self.body.generate_execution_code(code)
+
+    def generate_function_definitions(self, env, code):
+        self.body.generate_function_definitions(env, code)
diff --git a/tests/memoryview/memoryview.pyx b/tests/memoryview/memoryview.pyx
--- a/tests/memoryview/memoryview.pyx
+++ b/tests/memoryview/memoryview.pyx
@@ -247,7 +247,7 @@
     >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="ccqii"))
     [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]
     """
-    buf = mslice
+    cdef object buf = mslice
     print sorted([(k, int(v)) for k, v in buf[0].items()])
 
 def nested_struct(NestedStruct[:] mslice):
@@ -259,7 +259,7 @@
     >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i"))
     1 2 3 4 5
     """
-    buf = mslice
+    cdef object buf = mslice
     d = buf[0]
     print d['x']['a'], d['x']['b'], d['y']['a'], d['y']['b'], d['z']
 
@@ -275,7 +275,7 @@
     1 2
 
     """
-    buf = mslice
+    cdef object buf = mslice
     print buf[0]['a'], buf[0]['b']
 
 def nested_packed_struct(NestedPackedStruct[:] mslice):
@@ -289,7 +289,7 @@
     >>> nested_packed_struct(NestedPackedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="^c@i^ci@i"))
     1 2 3 4 5
     """
-    buf = mslice
+    cdef object buf = mslice
     d = buf[0]
     print d['a'], d['b'], d['sub']['a'], d['sub']['b'], d['c']
 
@@ -299,7 +299,7 @@
     >>> complex_dtype(LongComplexMockBuffer(None, [(0, -1)]))
     -1j
     """
-    buf = mslice
+    cdef object buf = mslice
     print buf[0]
 
 def complex_inplace(long double complex[:] mslice):
@@ -307,7 +307,7 @@
     >>> complex_inplace(LongComplexMockBuffer(None, [(0, -1)]))
     (1+1j)
     """
-    buf = mslice
+    cdef object buf = mslice
     buf[0] = buf[0] + 1 + 2j
     print buf[0]
 
@@ -318,7 +318,7 @@
     >>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
     0.0 -1.0
     """
-    buf = mslice
+    cdef object buf = mslice
     print buf[0]['real'], buf[0]['imag']
 
 #
@@ -356,7 +356,7 @@
         ...
     IndexError: Out of bounds on buffer access (axis 1)
     """
-    buf = mslice
+    cdef object buf = mslice
     return buf[i, j]
 
 def set_int_2d(int[:, :] mslice, int i, int j, int value):
@@ -409,11 +409,48 @@
     IndexError: Out of bounds on buffer access (axis 1)
 
     """
-    buf = mslice
+    cdef object buf = mslice
     buf[i, j] = value
 
 
 #
+# auto type inference
+# (note that for most numeric types "might_overflow" stops the type inference from working well)
+#
+def type_infer(double[:, :] arg):
+    """
+    >>> type_infer(DoubleMockBuffer(None, range(6), (2,3)))
+    double
+    double[:]
+    double[:]
+    double[:, :]
+    """
+    a = arg[0,0]
+    print(cython.typeof(a))
+    b = arg[0]
+    print(cython.typeof(b))
+    c = arg[0,:]
+    print(cython.typeof(c))
+    d = arg[:,:]
+    print(cython.typeof(d))
+
+#
+# Loop optimization
+#
+@cython.test_fail_if_path_exists("//CoerceToPyTypeNode")
+def memview_iter(double[:, :] arg):
+    """
+    memview_iter(DoubleMockBuffer("C", range(6), (2,3)))
+    True
+    """
+    cdef double total = 0
+    for mview1d in arg:
+        for val in mview1d:
+            total += val
+    if total == 15:
+        return True
+
+#
 # Test all kinds of indexing and flags
 #
 
@@ -426,7 +463,7 @@
     >>> [str(x) for x in R.received_flags] # Py2/3
     ['FORMAT', 'ND', 'STRIDES', 'WRITABLE']
     """
-    buf = mslice
+    cdef object buf = mslice
     buf[2, 2, 1] = 23
 
 def strided(int[:] mslice):
@@ -441,7 +478,7 @@
     >>> A.release_ok
     True
     """
-    buf = mslice
+    cdef object buf = mslice
     return buf[2]
 
 def c_contig(int[::1] mslice):
@@ -450,7 +487,7 @@
     >>> c_contig(A)
     2
     """
-    buf = mslice
+    cdef object buf = mslice
     return buf[2]
 
 def c_contig_2d(int[:, ::1] mslice):
@@ -461,7 +498,7 @@
     >>> c_contig_2d(A)
     7
     """
-    buf = mslice
+    cdef object buf = mslice
     return buf[1, 3]
 
 def f_contig(int[::1, :] mslice):
@@ -470,7 +507,7 @@
     >>> f_contig(A)
     2
     """
-    buf = mslice
+    cdef object buf = mslice
     return buf[0, 1]
 
 def f_contig_2d(int[::1, :] mslice):
@@ -481,7 +518,7 @@
     >>> f_contig_2d(A)
     7
     """
-    buf = mslice
+    cdef object buf = mslice
     return buf[3, 1]
 
 def generic(int[::view.generic, ::view.generic] mslice1,
@@ -552,7 +589,7 @@
        ...
     ValueError: Buffer dtype mismatch, expected 'td_cy_int' but got 'short'
     """
-    buf = mslice
+    cdef object buf = mslice
     cdef int i
     for i in range(shape[0]):
         print buf[i],
@@ -567,7 +604,7 @@
        ...
     ValueError: Buffer dtype mismatch, expected 'td_h_short' but got 'int'
     """
-    buf = mslice
+    cdef object buf = mslice
     cdef int i
     for i in range(shape[0]):
         print buf[i],
@@ -582,7 +619,7 @@
        ...
     ValueError: Buffer dtype mismatch, expected 'td_h_cy_short' but got 'int'
     """
-    buf = mslice
+    cdef object buf = mslice
     cdef int i
     for i in range(shape[0]):
         print buf[i],
@@ -597,7 +634,7 @@
        ...
     ValueError: Buffer dtype mismatch, expected 'td_h_ushort' but got 'short'
     """
-    buf = mslice
+    cdef object buf = mslice
     cdef int i
     for i in range(shape[0]):
         print buf[i],
@@ -612,7 +649,7 @@
        ...
     ValueError: Buffer dtype mismatch, expected 'td_h_double' but got 'float'
     """
-    buf = mslice
+    cdef object buf = mslice
     cdef int i
     for i in range(shape[0]):
         print buf[i],
@@ -649,7 +686,7 @@
     {4: 23} 2
     [34, 3] 2
     """
-    buf = mslice
+    cdef object buf = mslice
     cdef int i
     for i in range(shape[0]):
         print repr(buf[i]), (<PyObject*>buf[i]).ob_refcnt
@@ -670,7 +707,7 @@
     (2, 3)
     >>> decref(b)
     """
-    buf = mslice
+    cdef object buf = mslice
     buf[idx] = obj
 
 def assign_temporary_to_object(object[:] mslice):
@@ -697,7 +734,7 @@
     >>> assign_to_object(A, 1, a)
     >>> decref(a)
     """
-    buf = mslice
+    cdef object buf = mslice
     buf[1] = {3-2: 2+(2*4)-2}
 
 
@@ -745,7 +782,7 @@
 
     """
     cdef int[::view.generic, ::view.generic, :] _a = arg
-    a = _a
+    cdef object a = _a
     b = a[2:8:2, -4:1:-1, 1:3]
 
     print b.shape
@@ -828,7 +865,7 @@
     released A
     """
     cdef int[:, :, :] _a = arg
-    a = _a
+    cdef object a = _a
     b = a[2:8:2, -4:1:-1, 1:3]
 
     print b.shape
@@ -856,7 +893,7 @@
     released A
     """
     cdef int[:, :, :] _a = arg
-    a = _a
+    cdef object a = _a
     b = a[-5:, 1, 1::2]
     c = b[4:1:-1, ::-1]
     d = c[2, 1:2]
diff --git a/tests/memoryview/memslice.pyx b/tests/memoryview/memslice.pyx
--- a/tests/memoryview/memslice.pyx
+++ b/tests/memoryview/memslice.pyx
@@ -1525,7 +1525,7 @@
     All dimensions preceding dimension 1 must be indexed and not sliced
     """
     cdef int[:, ::view.indirect, :] a = TestIndexSlicingDirectIndirectDims()
-    a_obj = a
+    cdef object a_obj = a
 
     print a[1][2][3]
     print a[1, 2, 3]
diff --git a/tests/memoryview/numpy_memoryview.pyx b/tests/memoryview/numpy_memoryview.pyx
--- a/tests/memoryview/numpy_memoryview.pyx
+++ b/tests/memoryview/numpy_memoryview.pyx
@@ -186,7 +186,7 @@
     numpy_obj = np.arange(4 * 3, dtype=np.int32).reshape(4, 3)
 
     a = numpy_obj
-    a_obj = a
+    cdef object a_obj = a
 
     cdef dtype_t[:, :] b = a.T
     print a.T.shape[0], a.T.shape[1]
@@ -244,7 +244,7 @@
     >>> test_copy_and_contig_attributes(a)
     """
     cdef np.int32_t[:, :] mslice = a
-    m = mslice
+    cdef object m = mslice  #  object copy
 
     # Test object copy attributes
     assert np.all(a == np.array(m.copy()))
diff --git a/tests/run/bytearray_iter.py b/tests/run/bytearray_iter.py
new file mode 100644
--- /dev/null
+++ b/tests/run/bytearray_iter.py
@@ -0,0 +1,90 @@
+# mode: run
+# tag: pure3, pure2
+
+import cython
+
+@cython.test_assert_path_exists("//ForFromStatNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+@cython.locals(x=bytearray)
+def basic_bytearray_iter(x):
+    """
+    >>> basic_bytearray_iter(bytearray(b"hello"))
+    h
+    e
+    l
+    l
+    o
+    """
+    for a in x:
+        print(chr(a))
+
+@cython.test_assert_path_exists("//ForFromStatNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+@cython.locals(x=bytearray)
+def reversed_bytearray_iter(x):
+    """
+    >>> reversed_bytearray_iter(bytearray(b"hello"))
+    o
+    l
+    l
+    e
+    h
+    """
+    for a in reversed(x):
+        print(chr(a))
+
+@cython.test_assert_path_exists("//ForFromStatNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+@cython.locals(x=bytearray)
+def modifying_bytearray_iter1(x):
+    """
+    >>> modifying_bytearray_iter1(bytearray(b"abcdef"))
+    a
+    b
+    c
+    3
+    """
+    count = 0
+    for a in x:
+        print(chr(a))
+        del x[-1]
+        count += 1
+    print(count)
+
+@cython.test_assert_path_exists("//ForFromStatNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+@cython.locals(x=bytearray)
+def modifying_bytearray_iter2(x):
+    """
+    >>> modifying_bytearray_iter2(bytearray(b"abcdef"))
+    a
+    c
+    e
+    3
+    """
+    count = 0
+    for a in x:
+        print(chr(a))
+        del x[0]
+        count += 1
+    print(count)
+
+@cython.test_assert_path_exists("//ForFromStatNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+@cython.locals(x=bytearray)
+def modifying_reversed_bytearray_iter(x):
+    """
+    NOTE - I'm not 100% sure how well-defined this behaviour is in Python.
+    However, for the moment Python and Cython seem to do the same thing.
+    Testing that it doesn't crash is probably more important than the exact output!
+    >>> modifying_reversed_bytearray_iter(bytearray(b"abcdef"))
+    f
+    f
+    f
+    f
+    f
+    f
+    """
+    for a in reversed(x):
+        print(chr(a))
+        del x[0]