#
# Copyright (c) 2017 Intel Corporation
# SPDX-License-Identifier: BSD-2-Clause
#


import sys
import numpy as np
import ast
import inspect
import operator
import types as pytypes
from contextlib import contextmanager
from copy import deepcopy

import numba
from numba import njit, stencil
from numba.core.utils import PYVERSION
from numba.core import types, registry
from numba.core.compiler import compile_extra, Flags
from numba.core.cpu import ParallelOptions
from numba.tests.support import tag, skip_parfors_unsupported, _32bit
from numba.core.errors import LoweringError, TypingError, NumbaValueError
import unittest


# Short local alias for the `skip_parfors_unsupported` decorator imported from
# numba.tests.support; it is applied to the test methods in this file.
skip_unsupported = skip_parfors_unsupported


# 2D stencil kernel: 0.25 times the sum of the elements at relative offsets
# (0, 1), (1, 0), (0, -1) and (-1, 0) from the current position.
@stencil
def stencil1_kernel(a):
    return 0.25 * (a[0, 1] + a[1, 0] + a[0, -1] + a[-1, 0])


# 1D stencil kernel with an explicit neighborhood of (-5, 0): 0.3 times the
# sum of the elements at relative offsets -5 through 0 (six elements).
@stencil(neighborhood=((-5, 0), ))
def stencil2_kernel(a):
    # start from the furthest element, then accumulate offsets -4..0
    cum = a[-5]
    for i in range(-4, 1):
        cum += a[i]
    return 0.3 * cum


# 2D stencil kernel declared with cval=1.0 that reads a single element at
# relative offset (-2, 2); used by test_stencil3 to check the cval option.
@stencil(cval=1.0)
def stencil3_kernel(a):
    return 0.25 * a[-2, 2]


# 2D stencil kernel over two input arrays: 0.25 times the sum of the elements
# at relative offsets (0, 1), (1, 0), (0, -1), (-1, 0) taken from both a and b.
@stencil
def stencil_multiple_input_kernel(a, b):
    return 0.25 * (a[0, 1] + a[1, 0] + a[0, -1] + a[-1, 0] +
                   b[0, 1] + b[1, 0] + b[0, -1] + b[-1, 0])


# Same neighbour sum as stencil_multiple_input_kernel, but the weight is the
# argument `w` instead of the literal 0.25 (w is used without indexing; the
# caller in test_stencil_multiple_inputs passes the scalar 0.25).
@stencil
def stencil_multiple_input_kernel_var(a, b, w):
    return w * (a[0, 1] + a[1, 0] + a[0, -1] + a[-1, 0] +
                b[0, 1] + b[1, 0] + b[0, -1] + b[-1, 0])


# 2D stencil kernel selecting, at the current position, the element of `a`
# when the element of `f` is truthy and the element of `b` otherwise.
@stencil
def stencil_multiple_input_mixed_types_2d(a, b, f):
    return a[0, 0] if f[0, 0] else b[0, 0]


# 1D stencil kernel in which `b` is listed in `standard_indexing`: `a` is
# accessed with relative offsets (-1 and 0) while b[0] and b[1] are plain
# absolute indexes (see the reference loop in
# test_stencil_standard_indexing_1d).
@stencil(standard_indexing=("b",))
def stencil_with_standard_indexing_1d(a, b):
    return a[-1] * b[0] + a[0] * b[1]


# 2D counterpart of stencil_with_standard_indexing_1d: each relative
# neighbour of `a` is multiplied by the element of `b` found at the same
# index pair used as an absolute index (so b[0, -1] and b[-1, 0] are ordinary
# negative indexes, as in the reference loop of
# test_stencil_standard_indexing_2d).
@stencil(standard_indexing=("b",))
def stencil_with_standard_indexing_2d(a, b):
    return (a[0, 1] * b[0, 1] + a[1, 0] * b[1, 0]
            + a[0, -1] * b[0, -1] + a[-1, 0] * b[-1, 0])


# njit-compiled helper returning its argument plus one.
@njit
def addone_njit(a):
    return a + 1


if not _32bit: # prevent compilation on unsupported 32bit targets
    # njit(parallel=True) counterpart of addone_njit; the name is only defined
    # when `_32bit` is false.
    @njit(parallel=True)
    def addone_pjit(a):
        return a + 1


@unittest.skipIf(PYVERSION != (3, 7), "Run under 3.7 only, AST unstable")
class TestStencilBase(unittest.TestCase):
    """Shared helpers for the stencil tests.

    Provides compilation of a Python function both as a plain njit function
    and with parallel options enabled, plus a `check` method comparing the
    interpreted, njit and parallel results against a reference
    implementation.
    """

    # NOTE(review): presumably consumed by the numba test runner; nothing in
    # this file reads it -- confirm before changing.
    _numba_parallel_test_ = False

    def __init__(self, *args, **kwargs):
        """Set up the compiler flags and initialise the ``TestCase``.

        Keyword arguments are accepted and forwarded so that subclasses which
        pass ``**kwargs`` (e.g. ``methodName=...``) can be constructed.
        """
        # flags for njit()
        self.cflags = Flags()
        self.cflags.nrt = True

        super(TestStencilBase, self).__init__(*args, **kwargs)

    def _compile_this(self, func, sig, flags):
        """Compile `func` for signature `sig` with `flags` on the CPU target.

        Returns whatever `compile_extra` returns; the callers in this file
        use its `entry_point` and `library` attributes.
        """
        return compile_extra(registry.cpu_target.typing_context,
                             registry.cpu_target.target_context, func, sig,
                             None, flags, {})

    def compile_parallel(self, func, sig, **kws):
        """Compile `func` with NRT and auto-parallelisation enabled.

        Any keyword arguments are handed to `ParallelOptions` as a dict of
        options; with no keyword arguments `ParallelOptions(True)` is used.
        """
        flags = Flags()
        flags.nrt = True
        options = True if not kws else kws
        flags.auto_parallel = ParallelOptions(options)
        return self._compile_this(func, sig, flags)

    def compile_njit(self, func, sig):
        """Compile `func` with the non-parallel flags built in `__init__`."""
        return self._compile_this(func, sig, flags=self.cflags)

    def compile_all(self, pyfunc, *args, **kwargs):
        """Compile `pyfunc` both ways for the types of the sample `args`.

        `kwargs` is accepted but not used. Returns the tuple
        ``(njit_result, parallel_result)``.
        """
        sig = tuple([numba.typeof(x) for x in args])
        # compile with parallel=True
        cpfunc = self.compile_parallel(pyfunc, sig)
        # compile a standard njit of the original function
        cfunc = self.compile_njit(pyfunc, sig)
        return cfunc, cpfunc

    def check(self, no_stencil_func, pyfunc, *args):
        """Compare `pyfunc` against the reference `no_stencil_func`.

        `pyfunc` is run interpreted, njit-compiled and parallel-compiled on
        `args`; each result must match `no_stencil_func(*args)` to three
        decimals, and the parallel compilation must contain the parfor
        scheduling call.
        """
        cfunc, cpfunc = self.compile_all(pyfunc, *args)
        # results without stencil macro
        expected = no_stencil_func(*args)
        # python result
        py_output = pyfunc(*args)

        # njit result
        njit_output = cfunc.entry_point(*args)

        # parfor result
        parfor_output = cpfunc.entry_point(*args)

        np.testing.assert_almost_equal(py_output, expected, decimal=3)
        np.testing.assert_almost_equal(njit_output, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output, expected, decimal=3)

        # make sure parfor set up scheduling
        self.assertIn('@do_scheduling', cpfunc.library.get_llvm_str())


class TestStencil(TestStencilBase):

    def __init__(self, *args, **kwargs):
        """Pass every construction argument straight to the base class."""
        super().__init__(*args, **kwargs)

    @skip_unsupported
    def test_stencil1(self):
        """Tests whether the optional out argument to stencil calls works.

        The same kernel is called with and without ``out=`` and both variants
        are checked against an explicit loop implementation.
        """
        # kernel call writing into a preallocated array passed as `out`
        def test_with_out(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.zeros(n**2).reshape((n, n))
            B = stencil1_kernel(A, out=B)
            return B

        # kernel call that returns a newly created result array
        def test_without_out(n):
            A = np.arange(n**2).reshape((n, n))
            B = stencil1_kernel(A)
            return B

        # reference: explicit loops over the interior, border left at zero
        def test_impl_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.zeros(n**2).reshape((n, n))
            for i in range(1, n - 1):
                for j in range(1, n - 1):
                    B[i, j] = 0.25 * (A[i, j + 1] +
                                      A[i + 1, j] + A[i, j - 1] + A[i - 1, j])
            return B

        n = 100
        self.check(test_impl_seq, test_with_out, n)
        self.check(test_impl_seq, test_without_out, n)

    @skip_unsupported
    def test_stencil2(self):
        """Tests whether the optional neighborhood argument to the stencil
        decorator works.

        Four variants are exercised in turn: a fixed neighborhood given to
        the decorator, a variable-width neighborhood given to a
        ``numba.stencil`` call, the same with ``index_offsets``, and a slice
        inside the kernel. The nested function names are deliberately reused
        (rebound) for each variant.
        """
        def test_seq(n):
            A = np.arange(n)
            B = stencil2_kernel(A)
            return B

        # reference: windowed sum over A[i-5 .. i], first five entries zero
        def test_impl_seq(n):
            A = np.arange(n)
            B = np.zeros(n)
            for i in range(5, len(A)):
                B[i] = 0.3 * sum(A[i - 5:i + 1])
            return B

        n = 100
        self.check(test_impl_seq, test_seq, n)
        # variable length neighborhood in numba.stencil call
        # only supported in parallel path

        def test_seq(n, w):
            A = np.arange(n)

            def stencil2_kernel(a, w):
                cum = a[-w]
                for i in range(-w + 1, w + 1):
                    cum += a[i]
                return 0.3 * cum
            B = numba.stencil(stencil2_kernel, neighborhood=((-w, w), ))(A, w)
            return B

        # reference: symmetric window A[i-w .. i+w]
        def test_impl_seq(n, w):
            A = np.arange(n)
            B = np.zeros(n)
            for i in range(w, len(A) - w):
                B[i] = 0.3 * sum(A[i - w:i + w + 1])
            return B
        n = 100
        w = 5
        cpfunc = self.compile_parallel(test_seq, (types.intp, types.intp))
        # `expected` is reused by the two variants that follow
        expected = test_impl_seq(n, w)
        # parfor result
        parfor_output = cpfunc.entry_point(n, w)
        np.testing.assert_almost_equal(parfor_output, expected, decimal=3)
        self.assertIn('@do_scheduling', cpfunc.library.get_llvm_str())
        # test index_offsets

        def test_seq(n, w, offset):
            A = np.arange(n)

            # every kernel index is shifted by +1; index_offsets=(-offset,)
            # with offset == 1 is expected to produce the same result as the
            # unshifted kernel (it is compared with the same `expected`)
            def stencil2_kernel(a, w):
                cum = a[-w + 1]
                for i in range(-w + 1, w + 1):
                    cum += a[i + 1]
                return 0.3 * cum
            B = numba.stencil(stencil2_kernel, neighborhood=((-w, w), ),
                              index_offsets=(-offset, ))(A, w)
            return B

        offset = 1
        cpfunc = self.compile_parallel(test_seq, (types.intp, types.intp,
                                                  types.intp))
        parfor_output = cpfunc.entry_point(n, w, offset)
        np.testing.assert_almost_equal(parfor_output, expected, decimal=3)
        self.assertIn('@do_scheduling', cpfunc.library.get_llvm_str())
        # test slice in kernel

        def test_seq(n, w, offset):
            A = np.arange(n)

            # same shifted window expressed as a slice and summed with np.sum
            def stencil2_kernel(a, w):
                return 0.3 * np.sum(a[-w + 1:w + 2])
            B = numba.stencil(stencil2_kernel, neighborhood=((-w, w), ),
                              index_offsets=(-offset, ))(A, w)
            return B

        offset = 1
        cpfunc = self.compile_parallel(test_seq, (types.intp, types.intp,
                                                  types.intp))
        parfor_output = cpfunc.entry_point(n, w, offset)
        np.testing.assert_almost_equal(parfor_output, expected, decimal=3)
        self.assertIn('@do_scheduling', cpfunc.library.get_llvm_str())

    @skip_unsupported
    def test_stencil3(self):
        """Tests whether a non-zero optional cval argument to the stencil
        decorator works.  Also tests integer result type.

        The kernel is run interpreted, with njit and with njit(parallel=True);
        in each result the elements [0, 0] and [4, 4] must equal 1.0, the
        cval given to ``stencil3_kernel``.
        """
        def test_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = stencil3_kernel(A)
            return B

        test_njit = njit(test_seq)
        test_par = njit(test_seq, parallel=True)

        n = 5
        seq_res = test_seq(n)
        njit_res = test_njit(n)
        par_res = test_par(n)

        self.assertTrue(seq_res[0, 0] == 1.0 and seq_res[4, 4] == 1.0)
        self.assertTrue(njit_res[0, 0] == 1.0 and njit_res[4, 4] == 1.0)
        self.assertTrue(par_res[0, 0] == 1.0 and par_res[4, 4] == 1.0)

    @skip_unsupported
    def test_stencil_standard_indexing_1d(self):
        """Tests standard indexing with a 1d array.

        The standard-indexed argument here is a two-element Python list of
        floats.
        """
        def test_seq(n):
            A = np.arange(n)
            B = [3.0, 7.0]
            C = stencil_with_standard_indexing_1d(A, B)
            return C

        # reference: B is indexed absolutely, A relative to i; C[0] stays 0
        def test_impl_seq(n):
            A = np.arange(n)
            B = [3.0, 7.0]
            C = np.zeros(n)

            for i in range(1, n):
                C[i] = A[i - 1] * B[0] + A[i] * B[1]
            return C

        n = 100
        self.check(test_impl_seq, test_seq, n)

    @skip_unsupported
    def test_stencil_standard_indexing_2d(self):
        """Tests standard indexing with a 2d array and multiple stencil calls.

        The kernel is applied twice in sequence, the second time to the
        output of the first.
        """
        def test_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.ones((3, 3))
            C = stencil_with_standard_indexing_2d(A, B)
            D = stencil_with_standard_indexing_2d(C, B)
            return D

        # reference: two passes of explicit loops over the interior; B is
        # indexed absolutely (including the negative indexes)
        def test_impl_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.ones((3, 3))
            C = np.zeros(n**2).reshape((n, n))
            D = np.zeros(n**2).reshape((n, n))

            for i in range(1, n - 1):
                for j in range(1, n - 1):
                    C[i, j] = (A[i, j + 1] * B[0, 1] + A[i + 1, j] * B[1, 0] +
                               A[i, j - 1] * B[0, -1] + A[i - 1, j] * B[-1, 0])
            for i in range(1, n - 1):
                for j in range(1, n - 1):
                    D[i, j] = (C[i, j + 1] * B[0, 1] + C[i + 1, j] * B[1, 0] +
                               C[i, j - 1] * B[0, -1] + C[i - 1, j] * B[-1, 0])
            return D

        n = 5
        self.check(test_impl_seq, test_seq, n)

    @skip_unsupported
    def test_stencil_multiple_inputs(self):
        """Tests whether multiple inputs of the same size work.

        A second variant passes an additional non-array (scalar) argument to
        the kernel and is checked against the same reference.
        """
        def test_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.arange(n**2).reshape((n, n))
            C = stencil_multiple_input_kernel(A, B)
            return C

        # reference: explicit loops over the interior summing both inputs
        def test_impl_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.arange(n**2).reshape((n, n))
            C = np.zeros(n**2).reshape((n, n))
            for i in range(1, n - 1):
                for j in range(1, n - 1):
                    C[i, j] = 0.25 * \
                        (A[i, j + 1] + A[i + 1, j]
                         + A[i, j - 1] + A[i - 1, j]
                         + B[i, j + 1] + B[i + 1, j]
                         + B[i, j - 1] + B[i - 1, j])
            return C

        n = 3
        self.check(test_impl_seq, test_seq, n)
        # test stencil with a non-array input

        # w == 0.25 matches the constant used by the reference above
        def test_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.arange(n**2).reshape((n, n))
            w = 0.25
            C = stencil_multiple_input_kernel_var(A, B, w)
            return C
        self.check(test_impl_seq, test_seq, n)

    @skip_unsupported
    def test_stencil_mixed_types(self):
        """Tests a stencil whose inputs have different dtypes: two arrays
        built from ``np.arange`` and a boolean selector array.
        """
        # reference: element-wise selection over the whole array
        def test_impl_seq(n):
            A = np.arange(n ** 2).reshape((n, n))
            B = n ** 2 - np.arange(n ** 2).reshape((n, n))
            S = np.eye(n, dtype=np.bool_)
            O = np.zeros((n, n), dtype=A.dtype)
            for i in range(0, n):
                for j in range(0, n):
                    O[i, j] = A[i, j] if S[i, j] else B[i, j]
            return O

        def test_seq(n):
            A = np.arange(n ** 2).reshape((n, n))
            B = n ** 2 - np.arange(n ** 2).reshape((n, n))
            S = np.eye(n, dtype=np.bool_)
            O = stencil_multiple_input_mixed_types_2d(A, B, S)
            return O

        n = 3
        self.check(test_impl_seq, test_seq, n)

    @skip_unsupported
    def test_stencil_call(self):
        """Tests 2D numba.stencil calls.

        Covers a lambda kernel combined with ``out=`` and a nested function
        kernel whose result is returned.
        """
        # kernel given as a lambda, result written through `out`
        def test_impl1(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.zeros(n**2).reshape((n, n))
            numba.stencil(lambda a: 0.25 * (a[0, 1] + a[1, 0] + a[0, -1]
                                            + a[-1, 0]))(A, out=B)
            return B

        # kernel given as a nested function, result taken from the call
        def test_impl2(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.zeros(n**2).reshape((n, n))

            def sf(a):
                return 0.25 * (a[0, 1] + a[1, 0] + a[0, -1] + a[-1, 0])
            B = numba.stencil(sf)(A)
            return B

        # reference: explicit loops over the interior
        def test_impl_seq(n):
            A = np.arange(n**2).reshape((n, n))
            B = np.zeros(n**2).reshape((n, n))
            for i in range(1, n - 1):
                for j in range(1, n - 1):
                    B[i, j] = 0.25 * (A[i, j + 1] + A[i + 1, j]
                                      + A[i, j - 1] + A[i - 1, j])
            return B

        n = 100
        self.check(test_impl_seq, test_impl1, n)
        self.check(test_impl_seq, test_impl2, n)

    @skip_unsupported
    def test_stencil_call_1D(self):
        """Tests 1D numba.stencil calls.

        A three-point lambda kernel writing through ``out=`` is checked
        against an explicit loop.
        """
        def test_impl(n):
            A = np.arange(n)
            B = np.zeros(n)
            numba.stencil(lambda a: 0.3 * (a[-1] + a[0] + a[1]))(A, out=B)
            return B

        # reference: explicit loop over the interior, both ends left at zero
        def test_impl_seq(n):
            A = np.arange(n)
            B = np.zeros(n)
            for i in range(1, n - 1):
                B[i] = 0.3 * (A[i - 1] + A[i] + A[i + 1])
            return B

        n = 100
        self.check(test_impl_seq, test_impl, n)

    @skip_unsupported
    def test_stencil_call_const(self):
        """Tests numba.stencil call that has an index that can be inferred as
        constant from a unary expr. Otherwise, this would raise an error since
        neighborhood length is not specified.

        All four variants use index expressions that evaluate to -1, 0 and 1,
        so they share one reference implementation. The exact form of each
        index expression is what is under test.
        """
        # index is a plain unary negation of the constant
        def test_impl1(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 1
            numba.stencil(lambda a,c : 0.3 * (a[-c] + a[0] + a[c]))(
                                                                   A, c, out=B)
            return B

        # index is a binary expression involving the constant
        def test_impl2(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 2
            numba.stencil(lambda a,c : 0.3 * (a[1-c] + a[0] + a[c-1]))(
                                                                   A, c, out=B)
            return B

        # recursive expr case
        def test_impl3(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 2
            numba.stencil(lambda a,c : 0.3 * (a[-c+1] + a[0] + a[c-1]))(
                                                                   A, c, out=B)
            return B

        # multi-constant case
        def test_impl4(n):
            A = np.arange(n)
            B = np.zeros(n)
            d = 1
            c = 2
            numba.stencil(lambda a,c,d : 0.3 * (a[-c+d] + a[0] + a[c-d]))(
                                                                A, c, d, out=B)
            return B

        # reference: three-point sum with offset c == 1
        def test_impl_seq(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 1
            for i in range(1, n - 1):
                B[i] = 0.3 * (A[i - c] + A[i] + A[i + c])
            return B

        n = 100
        # constant inference is only possible in parallel path
        cpfunc1 = self.compile_parallel(test_impl1, (types.intp,))
        cpfunc2 = self.compile_parallel(test_impl2, (types.intp,))
        cpfunc3 = self.compile_parallel(test_impl3, (types.intp,))
        cpfunc4 = self.compile_parallel(test_impl4, (types.intp,))
        expected = test_impl_seq(n)
        # parfor result
        parfor_output1 = cpfunc1.entry_point(n)
        parfor_output2 = cpfunc2.entry_point(n)
        parfor_output3 = cpfunc3.entry_point(n)
        parfor_output4 = cpfunc4.entry_point(n)
        np.testing.assert_almost_equal(parfor_output1, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output2, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output3, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output4, expected, decimal=3)

        # check error in regular Python path
        with self.assertRaises(NumbaValueError) as e:
            test_impl4(4)

        self.assertIn("stencil kernel index is not constant, "
                      "'neighborhood' option required", str(e.exception))
        # check error in njit path
        # TODO: ValueError should be thrown instead of LoweringError
        with self.assertRaises((LoweringError, NumbaValueError)) as e:
            njit(test_impl4)(4)

        self.assertIn("stencil kernel index is not constant, "
                      "'neighborhood' option required", str(e.exception))

    @skip_unsupported
    def test_stencil_parallel_off(self):
        """Tests a 1D numba.stencil call compiled with the ``stencil``
           parallel option turned off: the generated LLVM must not contain
           the parfor scheduling call.
        """
        def test_impl(A):
            return numba.stencil(lambda a: 0.3 * (a[-1] + a[0] + a[1]))(A)

        # stencil=False is forwarded to ParallelOptions by compile_parallel
        cpfunc = self.compile_parallel(test_impl, (numba.float64[:],), stencil=False)
        self.assertNotIn('@do_scheduling', cpfunc.library.get_llvm_str())

    @skip_unsupported
    def test_stencil_nested1(self):
        """Tests whether nested stencil decorator works.

        The ``@stencil`` kernel is defined inside an ``njit(parallel=True)``
        function and uses the index expression ``-c+1`` with ``c = 2``.
        """
        @njit(parallel=True)
        def test_impl(n):
            @stencil
            def fun(a):
                c = 2
                return a[-c+1]
            B = fun(n)
            return B

        # reference: each element takes the previous input element; B[0] is 0
        def test_impl_seq(n):
            B = np.zeros(len(n), dtype=int)
            for i in range(1, len(n)):
                B[i] = n[i-1]
            return B

        # note: here `n` is the input array, not a length
        n = np.arange(10)
        np.testing.assert_equal(test_impl(n), test_impl_seq(n))

    @skip_unsupported
    def test_out_kwarg_w_cval(self):
        """ Issue #3518, out kwarg did not work with cval.

        Checks the interpreted stencil call and both compiled variants for
        cval values 7 and 7.0, then checks that a complex cval is rejected
        when the output array has an integer dtype.
        """
        # test const value that matches the arg dtype, and one that can be cast
        const_vals = [7, 7.0]

        def kernel(a):
            return (a[0, 0] - a[1, 0])

        for const_val in const_vals:
            stencil_fn = numba.stencil(kernel, cval=const_val)

            def wrapped():
                A = np.arange(12).reshape((3, 4))
                ret = np.ones_like(A)
                stencil_fn(A, out=ret)
                return ret

            # stencil function case
            A = np.arange(12).reshape((3, 4))
            # rows differ by 4, so a[0, 0] - a[1, 0] is -4; the last row is
            # expected to hold the cval
            expected = np.full_like(A, -4)
            expected[-1, :] = const_val
            ret = np.ones_like(A)
            stencil_fn(A, out=ret)
            np.testing.assert_almost_equal(ret, expected)

            # wrapped function case, check njit, then njit(parallel=True)
            impls = self.compile_all(wrapped,)
            for impl in impls:
                got = impl.entry_point()
                np.testing.assert_almost_equal(got, expected)

        # now check exceptions for cval dtype mismatch with out kwarg dtype
        stencil_fn = numba.stencil(kernel, cval=1j)

        def wrapped():
            A = np.arange(12).reshape((3, 4))
            ret = np.ones_like(A)
            stencil_fn(A, out=ret)
            return ret

        A = np.arange(12).reshape((3, 4))
        ret = np.ones_like(A)
        with self.assertRaises(NumbaValueError) as e:
            stencil_fn(A, out=ret)
        msg = "cval type does not match stencil return type."
        self.assertIn(msg, str(e.exception))

        # both compilers must fail at compile time with the same message
        for compiler in [self.compile_njit, self.compile_parallel]:
            try:
                compiler(wrapped,())
            except(NumbaValueError, LoweringError) as e:
                self.assertIn(msg, str(e))
            else:
                raise AssertionError("Expected error was not raised")

    @skip_unsupported
    def test_out_kwarg_w_cval_np_attr(self):
        """ Test issue #7286 where the cval is a np attr/string-based numerical
        constant

        Each of nan, +/-inf (spelled through both ``np`` and ``float``) is
        used as cval with a float input and an ``out=`` array.
        """
        for cval in (np.nan, np.inf, -np.inf, float('inf'), -float('inf')):
            def kernel(a):
                return (a[0, 0] - a[1, 0])

            stencil_fn = numba.stencil(kernel, cval=cval)

            def wrapped():
                A = np.arange(12.).reshape((3, 4))
                ret = np.ones_like(A)
                stencil_fn(A, out=ret)
                return ret

            # stencil function case
            A = np.arange(12.).reshape((3, 4))
            # rows differ by 4, so the kernel yields -4; the last row is
            # expected to hold the cval
            expected = np.full_like(A, -4)
            expected[-1, :] = cval
            ret = np.ones_like(A)
            stencil_fn(A, out=ret)
            np.testing.assert_almost_equal(ret, expected)

            # wrapped function case, check njit, then njit(parallel=True)
            impls = self.compile_all(wrapped,)
            for impl in impls:
                got = impl.entry_point()
                np.testing.assert_almost_equal(got, expected)


class pyStencilGenerator:
    """
    Holds the classes and methods needed to generate a python stencil
    implementation from a kernel purely using AST transforms.
    """

    class Builder:
        """
        Code-generation helpers for the AST manipulation pipeline.
        Nearly every method builds and returns an AST node or a small tree.
        """

        # identifier pool: 'a'..'z' followed by 'A'..'Z'
        ids = [chr(base + k) for base in (ord('a'), ord('A'))
               for k in range(26)]

        def __init__(self):
            # counter handed out by `varidx`
            self.__state = 0

        def varidx(self):
            """
            Returns the next value of a monotonically increasing counter,
            for use in labelling variables.
            """
            current = self.__state
            self.__state = current + 1
            return current

        # builder functions
        def gen_alloc_return(self, orig, var, dtype_var, init_val=0):
            """
            Builds the AST for:
                `var = np.full(orig.shape, init_val, dtype=type(<name>))`
            where `<name>` is the identifier held in `dtype_var.id`.
            """
            target = ast.Name(id=var, ctx=ast.Store())
            np_full = self.gen_attr('np', 'full')
            shape = self.gen_attr(orig, 'shape')
            fill = self.gen_num(init_val)
            # gen_call wraps the call in an Expr; only the Call is wanted
            dtype_expr = self.gen_call('type', [dtype_var.id]).value
            alloc = ast.Call(
                func=np_full,
                args=[shape, fill],
                keywords=[ast.keyword(arg='dtype', value=dtype_expr)],
                starargs=None,
                kwargs=None)
            return ast.Assign(targets=[target], value=alloc)

        def gen_assign(self, var, value, index_names):
            """
            Builds the AST for:
                `var[(*index_names,)] = value`
            with `value` being an already constructed AST expression.
            """
            index_tuple = ast.Tuple(
                elts=[ast.Name(id=name, ctx=ast.Load())
                      for name in index_names],
                ctx=ast.Load())
            target = ast.Subscript(
                value=ast.Name(id=var, ctx=ast.Load()),
                slice=ast.Index(value=index_tuple),
                ctx=ast.Store())
            return ast.Assign(targets=[target], value=value)

        def gen_loop(self, var, start=0, stop=0, body=None):
            """
            Builds the AST for `for var in range(start, stop): body`.
            Integer bounds are wrapped into number nodes, anything else is
            taken to be an AST node already and used as is.
            """
            def as_node(bound):
                if isinstance(bound, int):
                    return ast.Num(n=bound)
                return bound

            bounds = [as_node(start), as_node(stop)]
            range_call = ast.Call(
                func=ast.Name(id='range', ctx=ast.Load()),
                args=bounds,
                keywords=[],
                starargs=None, kwargs=None)
            loop_var = ast.Name(id=var, ctx=ast.Store())
            return ast.For(target=loop_var, iter=range_call,
                           body=body, orelse=[])

        def gen_return(self, var):
            """
            Builds the AST for `return var`
            """
            returned = ast.Name(id=var, ctx=ast.Load())
            return ast.Return(value=returned)

        def gen_slice(self, value):
            """Builds an Index node wrapping the number `value`"""
            number = ast.Num(n=value)
            return ast.Index(value=number)

        def gen_attr(self, name, attr):
            """
            Builds the AST for the attribute access `name.attr`
            """
            owner = ast.Name(id=name, ctx=ast.Load())
            return ast.Attribute(value=owner, attr=attr, ctx=ast.Load())

        def gen_subscript(self, name, attr, index, offset=None):
            """
            Builds the AST for `name.attr[index]`, followed by `+ offset` or
            `- abs(offset)` when a non-zero `offset` is supplied.
            """
            result = ast.Subscript(
                value=self.gen_attr(name, attr),
                slice=self.gen_slice(index),
                ctx=ast.Load())
            if not offset:
                return result
            if offset >= 0:
                sign = ast.Add()
            else:
                sign = ast.Sub()
            return ast.BinOp(left=result, op=sign,
                             right=ast.Num(n=abs(offset)))

        def gen_num(self, value):
            """
            Builds a number node holding `value`
            """
            # pretend bools are ints, ast has no boolean literal support
            if isinstance(value, bool):
                return ast.Num(int(value))
            if not abs(value) >= 0:
                return ast.UnaryOp(ast.USub(), ast.Num(-value))
            return ast.Num(value)

        def gen_call(self, call_name, args, kwargs=None):
            """
            Builds an expression statement for the call
                `call_name(*args, **kwargs)`
            Every positional argument becomes a name; every entry `k` of
            `kwargs` becomes the keyword `k=<expression parsed from str(k)>`.
            """
            positional = [ast.Name(id='%s' % x, ctx=ast.Load()) for x in args]
            keywords = []
            if kwargs is not None:
                for x in kwargs:
                    keywords.append(ast.keyword(
                        arg='%s' % x,
                        value=ast.parse(str(x)).body[0].value))
            call = ast.Call(
                func=ast.Name(id=call_name, ctx=ast.Load()),
                args=positional,
                keywords=keywords,
                starargs=None, kwargs=None)
            return ast.Expr(value=call, ctx=ast.Load())

    # AST transformers
    class FoldConst(ast.NodeTransformer, Builder):
        """
        Collapses binary operations whose two operands are both numeric
        literals into a single literal, so that constant expressions in the
        relative indexes are easier to handle downstream.
        """

        # just support a few for testing purposes
        supported_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
        }

        def visit_BinOp(self, node):
            """
            Visits the children first (so nested constant expressions fold
            bottom-up), then replaces `node` by a number node when the
            operator is supported and both operands are number nodes.
            Otherwise the (visited) node is returned unchanged.
            """
            node = self.generic_visit(node)

            fold = self.supported_ops.get(type(node.op))
            left = getattr(node, 'left', None)
            right = getattr(node, 'right', None)

            if left and right and fold:
                both_literal = (isinstance(left, ast.Num)
                                and isinstance(right, ast.Num))
                if both_literal:
                    return ast.Num(fold(left.n, right.n))
            return node

    class FixRelIndex(ast.NodeTransformer, Builder):
        """ Fixes the relative indexes to be written in as
        induction index + relative index
        """

        def __init__(self, argnames, const_assigns,
                     standard_indexing, neighborhood, *args, **kwargs):
            """
            Stores the kernel description and resets the index bookkeeping.
            Extra positional/keyword arguments go to both base initialisers.
            """
            ast.NodeTransformer.__init__(self, *args, **kwargs)
            pyStencilGenerator.Builder.__init__(self, *args, **kwargs)

            # kernel description supplied by the caller
            self._argnames = argnames
            self._const_assigns = const_assigns
            self._neighborhood = neighborhood
            self._standard_indexing = standard_indexing or []

            # bookkeeping: dimensionality (-1 until known) and index extents
            self._idx_len = -1
            self._mins = None
            self._maxes = None
            int_limits = np.iinfo(int)
            self._imin = int_limits.min
            self._imax = int_limits.max

            # naming pattern for the induction variables
            if neighborhood:
                self._id_pat = '__%sn'
            else:
                self._id_pat = '__%s'

        def get_val_from_num(self, node):
            """
            Gets the literal value from a Num or from a signed Num.

            A number node yields its value. A unary minus applied to a number
            node yields the negated value and a unary plus yields the value
            unchanged. Any other unary operator (e.g. `~` or `not`) is not a
            valid relative index and raises, instead of being silently
            treated as a negation.

            Raises ValueError for any node that cannot be interpreted.
            """
            if isinstance(node, ast.Num):
                return node.n
            if isinstance(node, ast.UnaryOp):
                if isinstance(node.op, ast.USub):
                    return -node.operand.n
                if isinstance(node.op, ast.UAdd):
                    return +node.operand.n
            raise ValueError(
                "get_val_from_num: Unknown indexing operation")

        def visit_Subscript(self, node):
            """
            Transforms subscripts of the form `a[x]` and `a[x, y, z, ...]`
            where `x, y, z` are relative indexes, to forms such as:
            `a[x + i]` and `a[x + i, y + j, z + k]` for use in loop induced
            indexing.
            """

            def handle2dindex(node):
                idx = []
                for x, val in enumerate(node.slice.value.elts):
                    useval = self._const_assigns.get(val, val)
                    idx.append(
                        ast.BinOp(
                            left=ast.Name(
                                id=self._id_pat % self.ids[x],
                                ctx=ast.Load()),
                            op=ast.Add(),
                            right=useval,
                            ctx=ast.Load()))
                if self._idx_len == -1:
                    self._idx_len = len(idx)
                else:
                    if(self._idx_len != len(idx)):
                        raise ValueError(
                            "Relative indexing mismatch detected")
                if isinstance(node.ctx, ast.Store):
                    msg = ("Assignments to array passed to "
                            "stencil kernels is not allowed")
                    raise ValueError(msg)
                context = ast.Load()
                newnode = ast.Subscript(
                    value=node.value,
                    slice=ast.Index(
                        value=ast.Tuple(
                            elts=idx,
                            ctx=ast.Load()),
                        ctx=ast.Load()),
                    ctx=context)
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)

                # now work out max/min for index ranges i.e. stencil size
                if self._mins is None and self._maxes is None:
                    # first pass
                    self._mins = [self._imax] * self._idx_len
                    self._maxes = [self._imin] * self._idx_len

                if not self._neighborhood:
                    for x, lnode in enumerate(node.slice.value.elts):
                        if isinstance(lnode, ast.Num) or\
                                isinstance(lnode, ast.UnaryOp):
                            relvalue = self.get_val_from_num(lnode)
                        elif (hasattr(lnode, 'id') and
                                lnode.id in self._const_assigns):
                            relvalue = self._const_assigns[lnode.id]
                        else:
                            raise ValueError(
                                "Cannot interpret indexing value")
                        if relvalue < self._mins[x]:
                            self._mins[x] = relvalue
                        if relvalue > self._maxes[x]:
                            self._maxes[x] = relvalue
                else:
                    for x, lnode in enumerate(self._neighborhood):
                        self._mins[x] = self._neighborhood[x][0]
                        self._maxes[x] = self._neighborhood[x][1]

                return newnode

            def handle1dindex(node):
                useval = self._const_assigns.get(
                    node.slice.value, node.slice.value)
                idx = ast.BinOp(left=ast.Name(
                                id=self._id_pat % self.ids[0],
                                ctx=ast.Load()),
                                op=ast.Add(),
                                right=useval,
                                ctx=ast.Load())
                if self._idx_len == -1:
                    self._idx_len = 1
                else:
                    if(self._idx_len != 1):
                        raise ValueError(
                            "Relative indexing mismatch detected")
                if isinstance(node.ctx, ast.Store):
                    msg = ("Assignments to array passed to "
                            "stencil kernels is not allowed")
                    raise ValueError(msg)
                context = ast.Load()
                newnode = ast.Subscript(
                    value=node.value,
                    slice=ast.Index(
                        value=idx,
                        ctx=ast.Load()),
                    ctx=context)
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)

                # now work out max/min for index ranges i.e. stencil size
                if self._mins is None and self._maxes is None:
                    # first pass
                    self._mins = [self._imax, ]
                    self._maxes = [self._imin, ]

                if not self._neighborhood:
                    if isinstance(node.slice.value, ast.Num) or\
                            isinstance(node.slice.value, ast.UnaryOp):
                        relvalue = self.get_val_from_num(node.slice.value)
                    elif (hasattr(node.slice.value, 'id') and
                            node.slice.value.id in self._const_assigns):
                        relvalue = self._const_assigns[node.slice.value.id]
                    else:
                        raise ValueError("Cannot interpret indexing value")
                    if relvalue < self._mins[0]:
                        self._mins[0] = relvalue
                    if relvalue > self._maxes[0]:
                        self._maxes[0] = relvalue
                else:
                    self._mins[0] = self._neighborhood[0][0]
                    self._maxes[0] = self._neighborhood[0][1]
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)
                return newnode

            def computeSlice(i, node):
                def gen_idx(val, x):
                    useval = self._const_assigns.get(val, val)
                    value = self.get_val_from_num(val)
                    tmp = ast.BinOp(
                            left=ast.Name(
                                id=self._id_pat % self.ids[x],
                                ctx=ast.Load()),
                            op=ast.Add(),
                            right=useval,
                            ctx=ast.Load())
                    ast.copy_location(tmp, node)
                    ast.fix_missing_locations(tmp)
                    return tmp
                newnode = ast.Slice(gen_idx(node.lower, i),
                                 gen_idx(node.upper, i),
                                 node.step)
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)
                return newnode

            def computeIndex(i, node):
                useval = self._const_assigns.get(node.value, node.value)
                idx = ast.BinOp(left=ast.Name(
                                id=self._id_pat % self.ids[i],
                                ctx=ast.Load()),
                                op=ast.Add(),
                                right=useval,
                                ctx=ast.Load())
                newnode = ast.Index(value=idx, ctx=ast.Load())
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)
                return newnode

            def handleExtSlice(node):
                idx = []
                for i, val in enumerate(node.slice.dims):
                    if isinstance(val, ast.Slice):
                        idx.append(computeSlice(i, val))
                    if isinstance(val, ast.Index):
                        idx.append(computeIndex(i, val))
                    # TODO: handle more node types
                if self._idx_len == -1:
                    self._idx_len = len(node.slice.dims)
                else:
                    if(self._idx_len != len(node.slice.dims)):
                        raise ValueError(
                            "Relative indexing mismatch detected")

                if isinstance(node.ctx, ast.Store):
                    msg = ("Assignments to array passed to "
                            "stencil kernels is not allowed")
                    raise ValueError(msg)
                context = ast.Load()
                newnode = ast.Subscript(
                    value=node.value,
                    slice=ast.ExtSlice(
                        dims=idx,
                        ctx=ast.Load()),
                        ctx=context
                        )

                # now work out max/min for index ranges i.e. stencil size
                if self._mins is None and self._maxes is None:
                    # first pass
                    self._mins = [self._imax] * self._idx_len
                    self._maxes = [self._imin] * self._idx_len

                if not self._neighborhood:
                    for x, anode in enumerate(node.slice.dims):
                        if isinstance(anode, ast.Slice):
                            for lnode in [anode.lower, anode.upper]:
                                if isinstance(lnode, ast.Num) or\
                                        isinstance(lnode, ast.UnaryOp):
                                    relvalue = self.get_val_from_num(lnode)
                                elif (hasattr(lnode, 'id') and
                                        lnode.id in self._const_assigns):
                                    relvalue = self._const_assigns[lnode.id]
                                else:
                                    raise ValueError(
                                        "Cannot interpret indexing value")
                                if relvalue < self._mins[x]:
                                    self._mins[x] = relvalue
                                if relvalue > self._maxes[x]:
                                    self._maxes[x] = relvalue
                        else:
                            val = anode.value
                            if isinstance(val, ast.Num) or\
                                    isinstance(val, ast.UnaryOp):
                                relvalue = self.get_val_from_num(val)
                            elif (hasattr(val, 'id') and
                                    val.id in self._const_assigns):
                                relvalue = self._const_assigns[val.id]
                            else:
                                raise ValueError(
                                    "Cannot interpret indexing value")
                            if relvalue < self._mins[x]:
                                self._mins[x] = relvalue
                            if relvalue > self._maxes[x]:
                                self._maxes[x] = relvalue
                else:
                    for x, lnode in enumerate(self._neighborhood):
                        self._mins[x] = self._neighborhood[x][0]
                        self._maxes[x] = self._neighborhood[x][1]
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)
                return newnode


            def handleSlice(node):
                idx = computeSlice(0, node.slice)
                idx.ctx=ast.Load()
                if isinstance(node.ctx, ast.Store):
                    msg = ("Assignments to array passed to "
                            "stencil kernels is not allowed")
                    raise ValueError(msg)
                context = ast.Load()
                newnode = ast.Subscript(
                    value=node.value,
                    slice=idx,
                    ctx=context)
                ast.copy_location(newnode, node)
                ast.fix_missing_locations(newnode)

                if self._idx_len == -1:
                    self._idx_len = 1
                else:
                    if(self._idx_len != 1):
                        raise ValueError(
                            "Relative indexing mismatch detected")

                # now work out max/min for index ranges i.e. stencil size
                if self._mins is None and self._maxes is None:
                    # first pass
                    self._mins = [self._imax]
                    self._maxes = [self._imin]

                if not self._neighborhood:
                    if isinstance(node.slice.value, ast.Num) or\
                            isinstance(node.slice.value, ast.UnaryOp):
                        relvalue = self.get_val_from_num(node.slice.value)
                    elif (hasattr(node.slice.value, 'id') and
                            node.slice.value.id in self._const_assigns):
                        relvalue = self._const_assigns[node.slice.value.id]
                    else:
                        raise ValueError("Cannot interpret indexing value")
                    if relvalue < self._mins[0]:
                        self._mins[0] = relvalue
                    if relvalue > self._maxes[0]:
                        self._maxes[0] = relvalue
                else:
                    self._mins[0] = self._neighborhood[0][0]
                    self._maxes[0] = self._neighborhood[0][1]

                return newnode


            node = self.generic_visit(node)
            if (node.value.id in self._argnames) and (
                    node.value.id not in self._standard_indexing):

                # fancy slice
                if isinstance(node.slice, ast.ExtSlice):
                    return handleExtSlice(node)
                # plain slice
                if isinstance(node.slice, ast.Slice):
                    return handleSlice(node)
                # 2D index
                if isinstance(node.slice.value, ast.Tuple):
                    return handle2dindex(node)
                # 1D index
                elif isinstance(node.slice, ast.Index):
                    return handle1dindex(node)
                else:  # unknown
                    raise ValueError("Unhandled subscript")
            else:
                return node

        @property
        def idx_len(self):
            """
            The number of index dimensions recorded by the transform.

            Raises ValueError while the internal counter still holds its
            sentinel value of -1, i.e. no relative index has been recorded.
            """
            if self._idx_len != -1:
                return self._idx_len
            raise ValueError(
                'Transform has not been run/no indexes found')

        @property
        def maxes(self):
            """
            The per-dimension upper extents of the stencil: for each dimension
            either the largest relative index seen or, when a neighborhood was
            supplied, that neighborhood's upper bound.
            NOTE(review): the value before any subscript has been processed is
            set outside this view (the handlers test it against None).
            """
            return self._maxes

        @property
        def mins(self):
            """
            The per-dimension lower extents of the stencil: for each dimension
            either the smallest relative index seen or, when a neighborhood
            was supplied, that neighborhood's lower bound.
            NOTE(review): the value before any subscript has been processed is
            set outside this view (the handlers test it against None).
            """
            return self._mins

        @property
        def id_pattern(self):
            """
            The %-format pattern used to build the names of the loop induction
            variables; it is applied as ``pattern % self.ids[dimension]``.
            """
            return self._id_pat

    class TransformReturns(ast.NodeTransformer, Builder):
        """
        Replaces each return node of the kernel with an assignment to a
        dedicated return variable.
        """

        def __init__(self, relidx_info, *args, **kwargs):
            """
            relidx_info - the relative index transformer that has already
                          been run; it supplies the number of dimensions and
                          the induction variable naming pattern.
            Remaining arguments are forwarded to both base initialisers.
            """
            ast.NodeTransformer.__init__(self, *args, **kwargs)
            pyStencilGenerator.Builder.__init__(self, *args, **kwargs)
            self._relidx_info = relidx_info
            # reserve a variable index so the return variable name is unique
            self._ret_var_idx = self.varidx()
            self._retvarname = '__b%s' % self._ret_var_idx

        def visit_Return(self, node):
            """
            Visits the children of the return node, then hands the returned
            expression to `gen_assign` together with the return variable name
            and one induction variable name per dimension.
            """
            self.generic_visit(node)
            relidx = self._relidx_info
            ndims = relidx.idx_len
            pattern = relidx.id_pattern
            index_names = []
            for dim in range(ndims):
                index_names.append(pattern % self.ids[dim])
            return self.gen_assign(self._retvarname, node.value, index_names)

        @property
        def ret_var_name(self):
            """
            The name of the variable that return values are assigned to
            """
            return self._retvarname

    class FixFunc(ast.NodeTransformer, Builder):
        """ The main function rewriter. It takes the (already transformed)
        body of the kernel and assembles a new entry point function holding:
         * the argument checking function and a call to it
         * the original kernel and the return dtype computation
         * the return value allocation
         * the loop nests around the kernel body
         * the return site
        """

        def __init__(self, kprops, relidx_info, ret_info,
                     cval, standard_indexing, neighborhood, *args, **kwargs):
            """
            kprops - kernel properties (original kernel, argument names and
                     return type assignment are read from it)
            relidx_info - the relative index transformer (dimension count,
                          variable pattern, mins and maxes are read from it)
            ret_info - the return transformer (return variable name)
            cval, standard_indexing, neighborhood - as per the @stencil
                                                    decorator
            Remaining arguments are forwarded to both base initialisers.
            """
            ast.NodeTransformer.__init__(self, *args, **kwargs)
            pyStencilGenerator.Builder.__init__(self, *args, **kwargs)
            self._original_kernel = kprops.original_kernel
            self._argnames = kprops.argnames
            self._retty = kprops.retty
            self._relidx_info = relidx_info
            self._ret_info = ret_info
            self._standard_indexing = standard_indexing or []
            self._neighborhood = neighborhood or tuple()
            # only the relatively indexed arguments are passed to the checker
            relidx_args = []
            for name in self._argnames:
                if name not in self._standard_indexing:
                    relidx_args.append(name)
            self._relidx_args = relidx_args
            # switch cval to python type
            self.cval = cval.tolist() if hasattr(cval, 'dtype') else cval
            # the first argument is the array the loop bounds are taken from
            self.stencil_arr = self._argnames[0]

        def visit_FunctionDef(self, node):
            """
            Replaces the kernel function definition with a new function named
            ``__<kernel name>`` that takes the kernel's arguments followed by
            a ``neighborhood`` argument and performs the stencil like
            behaviour on the kernel.
            """
            self.generic_visit(node)

            # this function validates arguments and is injected into the top
            # of the stencil call. NOTE: its source text is fetched with
            # inspect.getsource below, so it must stay a self-contained def.
            def check_stencil_arrays(*args, **kwargs):
                # the first has to be an array due to parfors requirements
                neighborhood = kwargs.get('neighborhood')
                init_shape = args[0].shape
                if neighborhood is not None:
                    if len(init_shape) != len(neighborhood):
                        raise ValueError("Invalid neighborhood supplied")
                for x in args[1:]:
                    if hasattr(x, 'shape'):
                        if init_shape != x.shape:
                            raise ValueError(
                                "Input stencil arrays do not commute")

            # turn the checker back into an AST node; only the FunctionDef is
            # needed, not the wrapping module
            checker_src = inspect.getsource(check_stencil_arrays).strip()
            check_impl = ast.parse(checker_src).body[0]
            ast.fix_missing_locations(check_impl)

            checker_call = self.gen_call(
                'check_stencil_arrays',
                self._relidx_args,
                kwargs=['neighborhood'])

            relidx = self._relidx_info
            nloops = relidx.idx_len
            var_pattern = relidx.id_pattern

            # create loop nests: each new loop wraps the previous one, so
            # dimension 0 ends up as the innermost loop
            loop_body = node.body
            for dim in range(nloops):
                lowest = relidx.mins[dim]
                highest = relidx.maxes[dim]
                # a negative minimum offset shifts the loop start forwards
                if lowest >= 0:
                    minbound = 0
                else:
                    minbound = -lowest
                # a positive maximum offset pulls the loop end backwards
                if highest > 0:
                    maxlim = -highest
                else:
                    maxlim = 0
                maxbound = self.gen_subscript(
                    self.stencil_arr, 'shape', dim, maxlim)
                loops = self.gen_loop(
                    var_pattern % self.ids[dim],
                    minbound, maxbound, body=loop_body)
                loop_body = [loops]

            # patch loop location
            ast.copy_location(loops, node)

            # allocate a return
            retty_target = self._retty.targets[0]
            retvar = self._ret_info.ret_var_name
            allocate = self.gen_alloc_return(
                self.stencil_arr, retvar, retty_target, self.cval)
            ast.copy_location(allocate, node)

            # generate the return
            returner = self.gen_return(retvar)
            ast.copy_location(returner, node)

            # the new signature is the kernel's arguments plus `neighborhood`
            signature = ast.arguments(
                args=node.args.args + [ast.arg('neighborhood', None)],
                defaults=[],
                vararg=None,
                kwarg=None,
                kwonlyargs=[],
                kw_defaults=[],
                posonlyargs=[])
            statements = [
                check_impl,
                checker_call,
                self._original_kernel,
                self._retty,
                allocate,
                loops,
                returner,
            ]
            new = ast.FunctionDef(
                name='__%s' % node.name,
                args=signature,
                body=statements,
                decorator_list=[])
            ast.copy_location(new, node)
            return new

    class GetKernelProps(ast.NodeVisitor, Builder):
        """ Gets the argument names and other properties
        of the original kernel.
        """

        def __init__(self, *args, **kwargs):
            """
            All arguments are forwarded to both base initialisers; the
            collected properties start out empty.
            """
            ast.NodeVisitor.__init__(self, *args, **kwargs)
            pyStencilGenerator.Builder.__init__(self, *args, **kwargs)
            self._argnames = None
            self._kwargnames = None
            self._retty = None
            self._original_kernel = None
            self._const_assigns = {}

        def visit_FunctionDef(self, node):
            """
            Records the kernel's argument names, builds the `__retdtype`
            assignment from a call to the kernel, and stores a deep copy of
            the unmodified kernel before visiting its body.

            Raises RuntimeError if a function definition was already recorded.
            """
            if self._argnames is not None or self._kwargnames is not None:
                raise RuntimeError("multiple definition of function/args?")

            self._argnames = [x.arg for x in node.args.args]
            if node.args.kwarg:
                # `arguments.kwarg` is a single `ast.arg` node (or None), not
                # a sequence, so it must not be iterated
                self._kwargnames = [node.args.kwarg.arg]
            # the value of the node produced by gen_call is reused as the
            # right hand side of the `__retdtype` assignment
            compute_retdtype = self.gen_call(node.name, self._argnames)
            self._retty = ast.Assign(targets=[ast.Name(
                id='__retdtype',
                ctx=ast.Store())], value=compute_retdtype.value)
            # copy before later passes mutate the tree in place
            self._original_kernel = ast.fix_missing_locations(deepcopy(node))
            self.generic_visit(node)

        def visit_Assign(self, node):
            """
            Records assignments of a numeric literal, optionally signed with
            a unary `+` or `-`, to a single plain name in the constant
            assignment map. Any other assignment is ignored.
            """
            self.generic_visit(node)
            tgt = node.targets
            if len(tgt) != 1:
                return
            target = tgt[0]
            if not isinstance(target, ast.Name):
                return
            value = node.value
            if isinstance(value, ast.Num):
                self._const_assigns[target.id] = value.n
            elif (isinstance(value, ast.UnaryOp) and
                    isinstance(value.operand, ast.Num)):
                # the sign lives on `.op` and the literal on `.operand`
                if isinstance(value.op, ast.UAdd):
                    self._const_assigns[target.id] = value.operand.n
                elif isinstance(value.op, ast.USub):
                    self._const_assigns[target.id] = -value.operand.n

        @property
        def argnames(self):
            """
            The names of the arguments to the function
            """
            return self._argnames

        @property
        def const_assigns(self):
            """
            A map of variable name to constant for variables that are simple
            constant assignments
            """
            return self._const_assigns

        @property
        def retty(self):
            """
            The assignment node that stores the result of calling the kernel
            in `__retdtype`
            """
            return self._retty

        @property
        def original_kernel(self):
            """
            The original unmutated kernel
            """
            return self._original_kernel

    class FixCalls(ast.NodeTransformer):
        """ Rebuilds call sites for astor (in case it is in use) """

        def visit_Call(self, node):
            """
            Visits the children of the call, then returns a fresh Call node
            reusing the original func, args and keywords with `starargs` and
            `kwargs` explicitly set to None.
            """
            self.generic_visit(node)
            # Add in starargs and kwargs to calls
            fields = dict(
                func=node.func,
                args=node.args,
                keywords=node.keywords,
                starargs=None,
                kwargs=None,
            )
            return ast.Call(**fields)

    def generate_stencil_tree(
            self, func, cval, standard_indexing, neighborhood):
        """
        Builds the AST of the stencil implementation of a kernel.

        func - a python stencil kernel, its source is obtained via
               inspect.getsource
        cval, standard_indexing and neighborhood as per the @stencil decorator

        The parsed module is run through the transformation passes in order
        and returned. A ValueError is raised if standard_indexing names
        something that is not an argument of the kernel.
        """
        tree = ast.parse(inspect.getsource(func).strip())

        # Prints debugging information if True.
        # If astor is installed the decompilation of the AST is also printed
        DEBUG = False
        if DEBUG:
            print("ORIGINAL")
            print(ast.dump(tree))

        # pass 1: collect the arg names and constant assignments
        props = self.GetKernelProps()
        props.visit(tree)
        arg_names = props.argnames
        const_assigns = props.const_assigns

        if standard_indexing and any(
                name not in arg_names for name in standard_indexing):
            msg = ("Non-existent variable "
                   "specified in standard_indexing")
            raise ValueError(msg)

        # pass 2: fold consts
        self.FoldConst().visit(tree)

        # pass 3: rewrite the relative indices as induced indices
        relidx_fixer = self.FixRelIndex(
            arg_names, const_assigns, standard_indexing, neighborhood)
        relidx_fixer.visit(tree)

        # pass 4: switch returns into assigns
        return_transformer = self.TransformReturns(relidx_fixer)
        return_transformer.visit(tree)

        # pass 5: generate the function body and loop nests and assemble
        func_fixer = self.FixFunc(
            props,
            relidx_fixer,
            return_transformer,
            cval,
            standard_indexing,
            neighborhood)
        func_fixer.visit(tree)

        # pass 6: fix up the call sites so they work better with astor
        self.FixCalls().visit(tree)
        ast.fix_missing_locations(tree.body[0])

        if DEBUG:
            print("\n\n\nNEW")
            print(ast.dump(tree, include_attributes=True))
            try:
                import astor
                print(astor.to_source(tree))
            except ImportError:
                pass

        return tree


def pyStencil(func_or_mode='constant', **options):
    """
    A pure python implementation of (a large subset of) stencil functionality,
    equivalent to StencilFunc.

    func_or_mode - either the kernel function itself, or a string naming the
                   mode style (only 'constant' is supported).
    options - any of `cval`, `standard_indexing` and `neighborhood`.

    When given a kernel, returns the generated python function implementing
    the stencil. When given a mode string, returns a decorator that performs
    the same generation on the kernel it is applied to.

    Raises ValueError for an unknown option or an unsupported mode.
    """

    if isinstance(func_or_mode, str):
        mode = func_or_mode
        func = None
    else:
        mode = 'constant'  # default style
        func = func_or_mode

    for option in options:
        if option not in ["cval", "standard_indexing", "neighborhood"]:
            raise ValueError("Unknown stencil option " + option)

    if mode != 'constant':
        raise ValueError("Unsupported mode style " + mode)

    cval = options.get('cval', 0)
    standard_indexing = options.get('standard_indexing', None)
    neighborhood = options.get('neighborhood', None)

    def build(kernel):
        # generate a new AST tree from the kernel func
        gen = pyStencilGenerator()
        tree = gen.generate_stencil_tree(kernel, cval, standard_indexing,
                                         neighborhood)

        # breathe life into the tree
        mod_code = compile(tree, filename="<ast>", mode="exec")
        # pick the function's code object by type rather than by position,
        # other constants (e.g. from annotations) may precede it
        func_code = next(const for const in mod_code.co_consts
                         if isinstance(const, pytypes.CodeType))
        return pytypes.FunctionType(func_code, globals())

    if func is None:
        # called with a mode string: defer generation until the kernel is
        # supplied
        return build
    return build(func)


@skip_unsupported
class TestManyStencils(TestStencilBase):

    def __init__(self, *args, **kwargs):
        """Forward all construction arguments unchanged to the base class."""
        super().__init__(*args, **kwargs)

    def check(self, pyfunc, *args, **kwargs):
        """
        Cross-check every stencil implementation of the kernel ``pyfunc``
        against the pure python reference.

        The reference result is computed from the pyStencil version of the
        kernel. Three further results are then produced:
        * by calling the pure @stencil decorated kernel,
        * by the njit of a trivial wrapper around the @stencil kernel,
        * by the njit(parallel=True) of the same trivial wrapper.

        Each result is compared with the reference (values to one decimal
        place, plus dtype). Recognised keyword arguments:
        * options: dict of stencil options passed to pyStencil and @stencil.
        * expected_exception: an exception type, an iterable of types, or a
          dict keyed by 'pyStencil', 'stencil', 'njit' and 'parfor' naming
          what each implementation is expected to raise.

        Raises RuntimeError when an implementation fails unexpectedly, does
        not match the reference, or does not raise although it should have.
        An exception of a type other than the expected one propagates, and
        ValueError is raised unless 1 to 3 positional arguments are given.
        """

        options = kwargs.get('options', dict())
        expected_exception = kwargs.get('expected_exception')

        # switch on to print every computed array
        DEBUG_OUTPUT = False

        # use cases that were meant to raise but ran to completion
        silent = []
        # (use case, description) pairs for failures that were not expected
        unexpected = []

        def describe(err):
            return "%s: %s" % (type(err), str(err))

        # context manager sorting raised/not raised into the lists above
        @contextmanager
        def errorhandler(exty=None, usecase=None):
            try:
                yield
            except Exception as e:
                if exty is None:
                    unexpected.append((usecase, describe(e)))
                else:
                    allowed = exty if hasattr(exty, '__iter__') else [exty]
                    if not any([isinstance(e, ex) for ex in allowed]):
                        raise
            else:
                if exty is not None:
                    silent.append(usecase)

        if isinstance(expected_exception, dict):
            pystencil_ex, stencil_ex, njit_ex, parfor_ex = [
                expected_exception[key]
                for key in ('pyStencil', 'stencil', 'njit', 'parfor')]
        else:
            pystencil_ex = stencil_ex = expected_exception
            njit_ex = parfor_ex = expected_exception

        stencil_args = {'func_or_mode': pyfunc, **options}

        expected_present = True

        try:
            # reference result from the AST based implementation
            reference_impl = pyStencil(func_or_mode=pyfunc, **options)
            expected = reference_impl(
                *args, neighborhood=options.get('neighborhood'))
            if DEBUG_OUTPUT:
                print("\nExpected:\n", expected)
        except Exception as caught:
            # route through the handler to check the exception is expected
            with errorhandler(pystencil_ex, "pyStencil"):
                raise caught
            pyStencil_unhandled_ex = caught
            expected_present = False

        stencilfunc_output = None
        with errorhandler(stencil_ex, "@stencil"):
            stencil_func_impl = stencil(**stencil_args)
            # stencil result
            stencilfunc_output = stencil_func_impl(*args)

        # fixed arity wrappers around the stencil, one per supported count
        nargs = len(args)
        if nargs == 1:
            def wrap_stencil(x0):
                return stencil_func_impl(x0)
        elif nargs == 2:
            def wrap_stencil(x0, x1):
                return stencil_func_impl(x0, x1)
        elif nargs == 3:
            def wrap_stencil(x0, x1, x2):
                return stencil_func_impl(x0, x1, x2)
        else:
            raise ValueError(
                "Up to 3 arguments can be provided, found %s" %
                len(args))

        sig = tuple(numba.typeof(arg) for arg in args)

        njit_output = None
        with errorhandler(njit_ex, "njit"):
            wrapped_cfunc = self.compile_njit(wrap_stencil, sig)
            # njit result
            njit_output = wrapped_cfunc.entry_point(*args)

        parfor_output = None
        with errorhandler(parfor_ex, "parfors"):
            wrapped_cpfunc = self.compile_parallel(wrap_stencil, sig)
            # parfor result
            parfor_output = wrapped_cpfunc.entry_point(*args)

        if DEBUG_OUTPUT:
            print("\n@stencil_output:\n", stencilfunc_output)
            print("\nnjit_output:\n", njit_output)
            print("\nparfor_output:\n", parfor_output)

        def compare(usecase, banner, anticipated, actual, extra_check=None):
            # compare one implementation with the reference, recording any
            # mismatch; nothing is compared if an exception was anticipated
            try:
                if not anticipated:
                    np.testing.assert_almost_equal(
                        actual, expected, decimal=1)
                    self.assertEqual(expected.dtype, actual.dtype)
                    if extra_check is not None:
                        extra_check()
            except Exception as e:
                unexpected.append((usecase, describe(e)))
                print("%s failed: %s" % (banner, str(e)))

        def scheduling_present():
            # the parallel build must contain the scheduling call
            try:
                self.assertIn(
                    '@do_scheduling',
                    wrapped_cpfunc.library.get_llvm_str())
            except AssertionError:
                msg = 'Could not find `@do_scheduling` in LLVM IR'
                raise AssertionError(msg)

        if expected_present:
            compare('@stencil', "@stencil", stencil_ex, stencilfunc_output)
            compare('njit', "@njit", njit_ex, njit_output)
            compare('parfors', "@njit(parallel=True)", parfor_ex,
                    parfor_output, scheduling_present)

        if DEBUG_OUTPUT:
            print("\n\n")

        if silent:
            msg = ["%s" % x for x in silent]
            raise RuntimeError(("The following implementations should have "
                                "raised an exception but did not:\n%s") % msg)

        if unexpected:
            impls = ["%s" % entry[0] for entry in unexpected]
            errs = ''.join(["%s: Message: %s\n\n" % entry
                            for entry in unexpected])
            str1 = ("The following implementations should not have raised an "
                    "exception but did:\n%s\n" % impls)
            str2 = "Errors were:\n\n%s" % errs
            raise RuntimeError(str1 + str2)

        if not expected_present and expected_exception is None:
            raise RuntimeError(
                "pyStencil failed, was not caught/expected",
                pyStencil_unhandled_ex)

    def exception_dict(self, **kwargs):
        """
        Build the per-implementation expected exception mapping.

        The keys 'pyStencil', 'stencil', 'njit' and 'parfor' all start out as
        None; any keyword argument overrides or adds the entry of that name.
        """
        expected = dict.fromkeys(('pyStencil', 'stencil', 'njit', 'parfor'))
        expected.update(kwargs)
        return expected

    def test_basic00(self):
        """Relative index a[0, 0] (no offset) on a 3x4 integer array."""
        def kernel(a):
            return a[0, 0]
        a = np.arange(12).reshape(3, 4)
        self.check(kernel, a)

    def test_basic01(self):
        """Relative index a[0, 1]: +1 offset in the last dimension."""
        def kernel(a):
            return a[0, 1]
        a = np.arange(12.).reshape(3, 4)
        self.check(kernel, a)

    def test_basic02(self):
        """Relative index a[0, -1]: -1 offset in the last dimension."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[0, -1]
        self.check(kernel, a)

    def test_basic03(self):
        """Relative index a[1, 0]: +1 offset in the first dimension."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[1, 0]
        self.check(kernel, a)

    def test_basic04(self):
        """Relative index a[-1, 0]: -1 offset in the first dimension."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-1, 0]
        self.check(kernel, a)

    def test_basic05(self):
        """Relative index a[-1, 1]: offsets of opposite sign per dimension."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-1, 1]
        self.check(kernel, a)

    def test_basic06(self):
        """Relative index a[1, -1]: offsets of opposite sign per dimension."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[1, -1]
        self.check(kernel, a)

    def test_basic07(self):
        """Relative index a[1, 1]: +1 offset in both dimensions."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[1, 1]
        self.check(kernel, a)

    def test_basic08(self):
        """Relative index a[-1, -1]: -1 offset in both dimensions."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-1, -1]
        self.check(kernel, a)

    def test_basic09(self):
        """Relative index a[-2, 2]: offsets of magnitude 2, opposite signs."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-2, 2]
        self.check(kernel, a)

    def test_basic10(self):
        """Sum of two relative indexes: a[0, 0] + a[1, 0]."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[1, 0]
        self.check(kernel, a)

    def test_basic11(self):
        """Sum of two relative indexes: a[-1, 0] + a[1, 0]."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-1, 0] + a[1, 0]
        self.check(kernel, a)

    def test_basic12(self):
        """Sum of two relative indexes: a[-1, 1] + a[1, -1]."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-1, 1] + a[1, -1]
        self.check(kernel, a)

    def test_basic13(self):
        """Sum of two relative indexes: a[-1, -1] + a[1, 1]."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-1, -1] + a[1, 1]
        self.check(kernel, a)

    def test_basic14(self):
        """Domain change: integer array element plus the complex literal
        1j."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + 1j
        self.check(kernel, a)

    def test_basic14b(self):
        """Domain change: integer array element plus a complex constant held
        in a local variable."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            t = 1.j
            return a[0, 0] + t
        self.check(kernel, a)

    def test_basic15(self):
        """Two relative indexes on an integer array plus a float constant."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[1, 0] + 1.
        self.check(kernel, a)

    def test_basic16(self):
        """Two relative indexes, one out of bounds: the offset of 10 in
        a[10, 0] exceeds the 3 rows of the array. IndexError is expected
        from pyStencil only."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[10, 0] + 1.

        # only pyStencil bounds checks
        ex = self.exception_dict(pyStencil=IndexError)
        self.check(kernel, a, expected_exception=ex)

    def test_basic17(self):
        """Boundary test: offset of +2 along the first axis, which has
        length 3."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[2, 0] + 1.
        self.check(kernel, a)

    def test_basic18(self):
        """Boundary test: offset of -2 along the first axis, which has
        length 3."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[-2, 0] + 1.
        self.check(kernel, a)

    def test_basic19(self):
        """Boundary test: offset of +3 along the last axis, which has
        length 4."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[0, 3] + 1.
        self.check(kernel, a)

    def test_basic20(self):
        """Boundary test: offset of -3 along the last axis, which has
        length 4."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[0, -3] + 1.
        self.check(kernel, a)

    def test_basic21(self):
        """The same relative index used twice, plus a float constant."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + a[0, 0] + 1.
        self.check(kernel, a)

    def test_basic22(self):
        """Relative index written as the constant expression ``1 + 0``,
        plus a float constant."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[1 + 0, 0] + a[0, 0] + 1.
        self.check(kernel, a)

    def test_basic23(self):
        """Relative index used in an intermediate computation (np.sin) whose
        result feeds the returned value."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            x = np.sin(10 + a[2, 1])
            return a[1 + 0, 0] + a[0, 0] + x
        self.check(kernel, a)

    def test_basic23a(self):
        """Relative index appearing only in dead code: ``x`` is computed from
        a[2, 1] but never used in the returned value."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            x = np.sin(10 + a[2, 1])
            return a[1 + 0, 0] + a[0, 0]
        self.check(kernel, a)

    def test_basic24(self):
        """1D index on a 2D array; ValueError or TypingError is expected."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return a[0] + 1.
        self.check(kernel, a, expected_exception=[ValueError, TypingError])

    def test_basic25(self):
        """Kernel that never indexes its 2D array argument; ValueError or
        NumbaValueError is expected."""
        a = np.arange(12).reshape(3, 4)

        def kernel(a):
            return 1.
        self.check(kernel, a, expected_exception=[ValueError, NumbaValueError])

    def test_basic26(self):
        """Relative indexing on a 3D (4x8x2) integer array."""
        a = np.arange(64).reshape(4, 8, 2)

        def kernel(a):
            return a[0, 0, 0] - a[0, 1, 0] + 1.
        self.check(kernel, a)

    def test_basic27(self):
        """Relative indexing on a 4D (4x8x2x2) integer array."""
        a = np.arange(128).reshape(4, 8, 2, 2)

        def kernel(a):
            return a[0, 0, 0, 0] - a[0, 1, 0, -1] + 1.
        self.check(kernel, a)

    def test_basic28(self):
        """Type widening: float32 array element plus an np.float64 scalar."""
        a = np.arange(12).reshape(3, 4).astype(np.float32)

        def kernel(a):
            return a[0, 0] + np.float64(10.)
        self.check(kernel, a)

    def test_basic29(self):
        """Index obtained from a function call, int(np.cos(0)); ValueError,
        NumbaValueError or LoweringError is expected."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[0, int(np.cos(0))]
        self.check(kernel, a, expected_exception=[ValueError, NumbaValueError,
                                                  LoweringError])

    def test_basic30(self):
        """Signed zeros as relative indexes: a[-0, -0]."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[-0, -0]
        self.check(kernel, a)

    def test_basic31(self):
        """Constant propagation, 2D: the relative index is a local variable
        assigned the literal 1."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            t = 1
            return a[t, 0]
        self.check(kernel, a)

    @unittest.skip("constant folding not implemented")
    def test_basic31b(self):
        """Constant propagation through an expression: the relative index is
        ``t = 1 - s`` with ``s = 1``."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            s = 1
            t = 1 - s
            return a[t, 0]
        self.check(kernel, a)

    def test_basic31c(self):
        """Constant propagation, 1D: the relative index is a local variable
        assigned the literal 1."""
        a = np.arange(12.)

        def kernel(a):
            t = 1
            return a[t]
        self.check(kernel, a)

    def test_basic32(self):
        """Typed integer index np.int8(1); ValueError, NumbaValueError or
        LoweringError is expected."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[np.int8(1), 0]
        self.check(kernel, a, expected_exception=[ValueError, NumbaValueError,
                                                  LoweringError])

    def test_basic33(self):
        """Array element plus a 0d array, np.array(1)."""
        a = np.arange(12.).reshape(3, 4)

        def kernel(a):
            return a[0, 0] + np.array(1)
        self.check(kernel, a)

    def test_basic34(self):
        """More complex kernel: an intermediate value built from a[0, 1] is
        added to further relative indexes, one passed through np.sin."""
        def kernel(a):
            g = 4. + a[0, 1]
            return g + (a[0, 1] + a[1, 0] + a[0, -1] + np.sin(a[-2, 0]))
        a = np.arange(144).reshape(12, 12)
        self.check(kernel, a)

    def test_basic35(self):
        """Simple cval: the cval is the int 5 while the array is float."""
        def kernel(a):
            return a[0, 1]
        a = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, options={'cval': 5})

    def test_basic36(self):
        """Four relative indexes summed, with a float cval of 5."""
        def kernel(a):
            return a[0, 1] + a[0, -1] + a[1, -1] + a[1, -1]
        a = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, options={'cval': 5.})

    def test_basic37(self):
        """cval supplied as the arithmetic expression ``5 + 63.``."""
        def kernel(a):
            return a[0, 1] + a[0, -1] + a[1, -1] + a[1, -1]
        a = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, options={'cval': 5 + 63.})

    def test_basic38(self):
        """Complex cval (1.j) with a float array. @stencil and njit are
        expected to raise NumbaValueError and parfor ValueError; no
        exception is expected from pyStencil."""
        def kernel(a):
            return a[0, 1] + a[0, -1] + a[1, -1] + a[1, -1]
        a = np.arange(12.).reshape(3, 4)
        ex = self.exception_dict(
            stencil=NumbaValueError,
            parfor=ValueError,
            njit=NumbaValueError)
        self.check(kernel, a, options={'cval': 1.j}, expected_exception=ex)

    def test_basic39(self):
        """cval supplied as the result of function calls,
        np.sin(3.) + np.cos(2)."""
        def kernel(a):
            return a[0, 1] + a[0, -1] + a[1, -1] + a[1, -1]
        a = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, options={'cval': np.sin(3.) + np.cos(2)})

    def test_basic40(self):
        """Two array arguments of the same shape, each relatively indexed."""
        def kernel(a, b):
            return a[0, 1] + b[0, -2]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b)

    def test_basic41(self):
        """Two relatively indexed arrays of very different size (3x4 and
        1x1); ValueError or AssertionError is expected."""
        def kernel(a, b):
            return a[0, 1] + b[0, -2]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(1.).reshape(1, 1)
        self.check(
            kernel, a, b, expected_exception=[
                ValueError, AssertionError])

    def test_basic42(self):
        """Two relatively indexed arrays of nearly the same size (3x4 and
        3x3); ValueError or AssertionError is expected."""
        def kernel(a, b):
            return a[0, 1] + b[0, -2]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(9.).reshape(3, 3)
        self.check(
            kernel, a, b, expected_exception=[
                ValueError, AssertionError])

    def test_basic43(self):
        """Two array arguments with two relative indexes on each."""
        def kernel(a, b):
            return a[0, 1] + a[1, 2] + b[-2, 0] + b[0, -1]
        a = np.arange(30.).reshape(5, 6)
        b = np.arange(30.).reshape(5, 6)
        self.check(kernel, a, b)

    def test_basic44(self):
        """Two args; the kernel assigns to a[0, 1] before reading it back.
        ValueError or LoweringError is expected."""
        def kernel(a, b):
            a[0, 1] = 12
            return a[0, 1]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel, a, b, expected_exception=[
                ValueError, LoweringError])

    def test_basic45(self):
        """Two args; the kernel assigns to a[0, 1] and then reads both it
        and a[1, 0]. ValueError or LoweringError is expected."""
        def kernel(a, b):
            a[0, 1] = 12
            return a[0, 1] + a[1, 0]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel, a, b, expected_exception=[
                ValueError, LoweringError])

    def test_basic46(self):
        """Two args; the kernel assigns a relative index of ``b`` to a
        relative index of ``a``. ValueError or LoweringError is expected."""
        def kernel(a, b):
            a[0, 1] = b[1, 2]
            return a[0, 1] + a[1, 0]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel, a, b, expected_exception=[
                ValueError, LoweringError])

    def test_basic47(self):
        """Three array arguments of the same shape, each relatively
        indexed."""
        def kernel(a, b, c):
            return a[0, 1] + b[1, 0] + c[-1, 0]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        c = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, c)

    # matches pyStencil, but all ought to fail
    # probably hard to detect?
    def test_basic48(self):
        """One arg; the kernel writes through ``a.T`` (an alias of the
        input) before reading a[0, 1]. No exception is expected."""
        def kernel(a):
            c = a.T
            c[:, :] = 10
            return a[0, 1]
        a = np.arange(12.).reshape(3, 4)
        self.check(kernel, a)

    def test_basic49(self):
        """Two args with standard_indexing on the second: ``a`` is indexed
        relatively, ``b`` is read at [0, 3]."""
        def kernel(a, b):
            return a[0, 1] + b[0, 3]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, options={'standard_indexing': 'b'})

    @unittest.skip("dynamic range checking not implemented")
    def test_basic50(self):
        """Two args with standard_indexing on the second, read out of bounds
        at b[0, 15] on a 3x4 array; IndexError is expected."""
        def kernel(a, b):
            return a[0, 1] + b[0, 15]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel,
            a,
            b,
            options={
                'standard_indexing': 'b'},
            expected_exception=IndexError)

    def test_basic51(self):
        """Two args, both listed in standard_indexing so no relatively
        indexed array remains; ValueError or NumbaValueError is expected."""
        def kernel(a, b):
            return a[0, 1] + b[0, 2]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel, a, b, options={
                'standard_indexing': [
                    'a', 'b']}, expected_exception=[
                ValueError, NumbaValueError])

    def test_basic52(self):
        """Three args with standard_indexing on the middle one, a 2x2 array
        that differs in shape from the 3x4 relatively indexed arrays."""
        def kernel(a, b, c):
            return a[0, 1] + b[0, 1] + c[1, 2]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(4.).reshape(2, 2)
        c = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, c, options={'standard_indexing': 'b'})

    def test_basic53(self):
        """Two args with standard_indexing naming 'c', which is not a kernel
        parameter. pyStencil and parfor are expected to raise ValueError;
        @stencil and njit are expected to raise some Exception."""
        def kernel(a, b):
            return a[0, 1] + b[0, 2]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        ex = self.exception_dict(
            pyStencil=ValueError,
            stencil=Exception,
            parfor=ValueError,
            njit=Exception)
        self.check(
            kernel,
            a,
            b,
            options={
                'standard_indexing': 'c'},
            expected_exception=ex)

    def test_basic54(self):
        """Two args, standard_indexing on ``b`` whose index comes from a
        local variable holding a literal."""
        def kernel(a, b):
            t = 2
            return a[0, 1] + b[0, t]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, options={'standard_indexing': 'b'})

    def test_basic55(self):
        """Two args, standard_indexing on ``b`` whose index is computed from
        another local (``t = 2 - s``)."""
        def kernel(a, b):
            s = 1
            t = 2 - s
            return a[0, 1] + b[0, t]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, options={'standard_indexing': 'b'})

    def test_basic56(self):
        """Two args, standard_indexing on ``b``: the kernel also loops over
        the slice b[0, :] to accumulate a sum and uses a computed index."""
        def kernel(a, b):
            s = 1
            acc = 0
            for k in b[0, :]:
                acc += k
            t = 2 - s - 1
            return a[0, 1] + b[0, t] + acc
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, options={'standard_indexing': 'b'})

    def test_basic57(self):
        """Two args, standard_indexing on ``b`` with the index operation
        split in two steps (``c = b[0]`` then ``c[1]``)."""
        def kernel(a, b):
            c = b[0]
            return a[0, 1] + c[1]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, options={'standard_indexing': 'b'})

    def test_basic58(self):
        """Two args, standard_indexing on ``b`` with a split index where the
        intermediate row has a scalar added (``c = b[0] + 1``)."""
        def kernel(a, b):
            c = b[0] + 1
            return a[0, 1] + c[1]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(kernel, a, b, options={'standard_indexing': 'b'})

    def test_basic59(self):
        """Three args: a relatively indexed array, a standard indexed array
        and an integer constant that is also listed in standard_indexing."""
        def kernel(a, b, c):
            return a[0, 1] + b[1, 1] + c
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        c = 10
        self.check(kernel, a, b, c, options={'standard_indexing': ['b', 'c']})

    def test_basic60(self):
        """Three args: a relatively indexed array, a standard indexed array
        and a tuple passed through with standard_indexing. ValueError is
        expected from parfor only."""
        def kernel(a, b, c):
            return a[0, 1] + b[1, 1] + c[0]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        c = (10,)
        # parfors does not support tuple args for stencil kernels
        ex = self.exception_dict(parfor=ValueError)
        self.check(
            kernel, a, b, c, options={
                'standard_indexing': [
                    'b', 'c']}, expected_exception=ex)

    def test_basic61(self):
        """Two args with standard_indexing on the first one; every
        implementation is expected to raise some Exception."""
        def kernel(a, b):
            return a[0, 1] + b[1, 1]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel,
            a,
            b,
            options={
                'standard_indexing': 'a'},
            expected_exception=Exception)

    def test_basic62(self):
        """Two args with both standard_indexing (on ``b``) and a cval."""
        def kernel(a, b):
            return a[0, 1] + b[1, 1]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12.).reshape(3, 4)
        self.check(
            kernel,
            a,
            b,
            options={
                'standard_indexing': 'b',
                'cval': 10.})

    def test_basic63(self):
        """Two args; a standard indexed element of ``b`` is used as a
        relative index into ``a``, i.e. a non-constant index. pyStencil and
        parfor are expected to raise ValueError, @stencil and njit
        NumbaValueError."""
        def kernel(a, b):
            return a[0, b[0, 1]]
        a = np.arange(12.).reshape(3, 4)
        b = np.arange(12).reshape(3, 4)
        ex = self.exception_dict(
            pyStencil=ValueError,
            stencil=NumbaValueError,
            parfor=ValueError,
            njit=NumbaValueError)
        self.check(
            kernel,
            a,
            b,
            options={
                'standard_indexing': 'b'},
            expected_exception=ex)

    # stencil, njit, parfors all fail. Does this make sense?
    def test_basic64(self):
        """Single arg that is itself listed in standard_indexing; ValueError
        or NumbaValueError is expected."""
        def kernel(a):
            return a[0, 0]
        a = np.arange(12.).reshape(3, 4)
        self.check(
            kernel,
            a,
            options={
                'standard_indexing': 'a'},
            expected_exception=[
                ValueError,
                NumbaValueError])

    def test_basic65(self):
        """Basic 1D neighborhood test: the loop variable is used as the
        relative index over the neighborhood (-29, 0)."""
        def kernel(a):
            cumul = 0
            for i in range(-29, 1):
                cumul += a[i]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-29, 0),)})

    # Should this work? a[0] is out of neighborhood?
    def test_basic66(self):
        """1D neighborhood (-29, 0) where the loop body reads the constant
        index a[0] rather than the loop variable."""
        def kernel(a):
            cumul = 0
            for i in range(-29, 1):
                cumul += a[0]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-29, 0),)})

    def test_basic67(self):
        """Basic 2D neighborhood test: both loop variables are used as
        relative indexes."""
        def kernel(a):
            cumul = 0
            for i in range(-5, 1):
                for j in range(-10, 1):
                    cumul += a[i, j]
            return cumul / (10 * 5)
        a = np.arange(10. * 20.).reshape(10, 20)
        self.check(kernel, a, options={'neighborhood': ((-5, 0), (-10, 0),)})

    def test_basic67b(self):
        """2D array with a neighborhood given for one dimension only;
        TypingError or ValueError is expected."""
        def kernel(a):
            cumul = 0
            for j in range(-10, 1):
                cumul += a[0, j]
            return cumul / (10 * 5)
        a = np.arange(10. * 20.).reshape(10, 20)
        self.check(
            kernel,
            a,
            options={
                'neighborhood': (
                    (-10,
                     0),
                )},
            expected_exception=[
                TypingError,
                ValueError])

    # Should this work or is it UB? a[i, 0] is out of neighborhood?
    def test_basic68(self):
        """2D neighborhood where the first index is a loop variable and the
        second is the constant 0."""
        def kernel(a):
            cumul = 0
            for i in range(-5, 1):
                for j in range(-10, 1):
                    cumul += a[i, 0]
            return cumul / (10 * 5)
        a = np.arange(10. * 20.).reshape(10, 20)
        self.check(kernel, a, options={'neighborhood': ((-5, 0), (-10, 0),)})

    # Should this work or is it UB? a[0, 0] is out of neighborhood?
    def test_basic69(self):
        """2D neighborhood where both indexes are the constant 0, so neither
        loop variable is used for indexing."""
        def kernel(a):
            cumul = 0
            for i in range(-5, 1):
                for j in range(-10, 1):
                    cumul += a[0, 0]
            return cumul / (10 * 5)
        a = np.arange(10. * 20.).reshape(10, 20)
        self.check(kernel, a, options={'neighborhood': ((-5, 0), (-10, 0),)})

    def test_basic70(self):
        """2D neighborhood with extra work in the loops: a value derived
        from the outer loop variable is added to every element."""
        def kernel(a):
            cumul = 0
            zz = 12.
            for i in range(-5, 1):
                t = zz + i
                for j in range(-10, 1):
                    cumul += a[i, j] + t
            return cumul / (10 * 5)
        a = np.arange(10. * 20.).reshape(10, 20)
        self.check(kernel, a, options={'neighborhood': ((-5, 0), (-10, 0),)})

    def test_basic71(self):
        """1D neighborhood with a type change: the added constant ``k`` is
        a float for part of the loop and the complex 1j for the rest."""
        def kernel(a):
            cumul = 0
            for i in range(-29, 1):
                k = 0.
                if i > -15:
                    k = 1j
                cumul += a[i] + k
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-29, 0),)})

    def test_basic72(self):
        """1D neighborhood (-29, 0) with a loop over range(-19, -3), a
        narrower range than the one specified."""
        def kernel(a):
            cumul = 0
            for i in range(-19, -3):
                cumul += a[i]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-29, 0),)})

    def test_basic73(self):
        """1D neighborhood with an entirely positive range, (5, 10)."""
        def kernel(a):
            cumul = 0
            for i in range(5, 11):
                cumul += a[i]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((5, 10),)})

    def test_basic73b(self):
        """1D neighborhood with an entirely negative range, (-10, -5)."""
        def kernel(a):
            cumul = 0
            for i in range(-10, -4):
                cumul += a[i]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-10, -5),)})

    def test_basic74(self):
        """1D neighborhood spanning negative to positive offsets,
        (-5, 10)."""
        def kernel(a):
            cumul = 0
            for i in range(-5, 11):
                cumul += a[i]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-5, 10),)})

    def test_basic75(self):
        """1D neighborhood spanning negative offsets only, (-10, -2)."""
        def kernel(a):
            cumul = 0
            for i in range(-10, -1):
                cumul += a[i]
            return cumul / 30
        a = np.arange(60.)
        self.check(kernel, a, options={'neighborhood': ((-10, -2),)})

    def test_basic76(self):
        """2D neighborhood with mixed spans: negative only in the first
        dimension (-3, -1), negative to positive in the second (-3, 3)."""
        def kernel(a):
            cumul = 0
            zz = 12.
            for i in range(-3, 0):
                t = zz + i
                for j in range(-3, 4):
                    cumul += a[i, j] + t
            return cumul / (10 * 5)
        a = np.arange(10. * 20.).reshape(10, 20)
        self.check(kernel, a, options={'neighborhood': ((-3, -1), (-3, 3),)})

    def test_basic77(self):
        """2D neighborhood applied to two array arguments of the same
        shape."""
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b[i, j]
            return cumul / (9.)
        a = np.arange(10. * 20.).reshape(10, 20)
        b = np.arange(10. * 20.).reshape(10, 20)
        self.check(kernel, a, b, options={'neighborhood': ((-3, 0), (-3, 0),)})

    def test_basic78(self):
        """Two array arguments with a 2D neighborhood that is negative-only
        on both axes (-6..-3 and -7..-2)."""
        def kernel(a, b):
            cumul = 0
            for i in range(-6, -2):
                for j in range(-7, -1):
                    cumul += a[i, j] + b[i, j]
            return cumul / (9.)

        shape = (15, 20)
        first = np.arange(15. * 20.).reshape(shape)
        second = np.arange(15. * 20.).reshape(shape)
        stencil_opts = {'neighborhood': ((-6, -3), (-7, -2),)}
        self.check(kernel, first, second, options=stencil_opts)

    def test_basic78b(self):
        """Two array arguments with a 2D neighborhood that is negative-only
        on the first axis (-6..-3) and positive-only on the second (2..9)."""
        def kernel(a, b):
            cumul = 0
            for i in range(-6, -2):
                for j in range(2, 10):
                    cumul += a[i, j] + b[i, j]
            return cumul / (9.)

        shape = (15, 20)
        first = np.arange(15. * 20.).reshape(shape)
        second = np.arange(15. * 20.).reshape(shape)
        stencil_opts = {'neighborhood': ((-6, -3), (2, 9),)}
        self.check(kernel, first, second, options=stencil_opts)

    def test_basic79(self):
        """Neighborhood kernel fed two arrays of differing dimensionality
        (2D ``a`` vs 3D ``b``).

        An expected exception type is supplied for each of the pyStencil,
        stencil, parfor and njit modes.
        """
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b[i, j]
            return cumul / (9.)

        a = np.arange(10. * 20.).reshape((10, 20))
        # Same number of elements as `a`, but one extra dimension.
        b = np.arange(10. * 20.).reshape((10, 10, 2))
        expected = self.exception_dict(
            pyStencil=ValueError,
            stencil=TypingError,
            parfor=TypingError,
            njit=TypingError,
        )
        stencil_opts = {'neighborhood': ((-3, 0), (-3, 0),)}
        self.check(kernel, a, b, options=stencil_opts,
                   expected_exception=expected)

    def test_basic80(self):
        """Neighborhood kernel whose accumulator changes type: a float
        array is summed together with a complex scalar argument."""
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b
            return cumul / (9.)

        grid = np.arange(10. * 20.).reshape((10, 20))
        scalar = 12.j  # complex, unlike the float64 array
        stencil_opts = {'neighborhood': ((-3, 0), (-3, 0))}
        self.check(kernel, grid, scalar, options=stencil_opts)

    def test_basic81(self):
        """Neighborhood kernel indexing a 2D array and a 1D array together.

        ``b`` is a copy of the first row of ``a``.  An expected exception
        type is supplied per mode; note that the parfor entry is
        AssertionError while stencil and njit use TypingError.
        """
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b[i]
            return cumul / (9.)

        a = np.arange(10. * 20.).reshape((10, 20))
        b = a[0].copy()
        expected = self.exception_dict(
            pyStencil=ValueError,
            stencil=TypingError,
            parfor=AssertionError,
            njit=TypingError,
        )
        stencil_opts = {'neighborhood': ((-3, 0), (-3, 0))}
        self.check(kernel, a, b, options=stencil_opts,
                   expected_exception=expected)

    def test_basic82(self):
        """Neighborhood combined with ``standard_indexing`` for ``b``: the
        kernel reads ``a`` at loop offsets and ``b`` at the fixed index
        [1, 3]."""
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b[1, 3]
            return cumul / (9.)

        a = np.arange(10. * 20.).reshape((10, 20))
        b = a.copy()
        stencil_opts = {
            'neighborhood': ((-3, 0), (-3, 0)),
            'standard_indexing': 'b',
        }
        self.check(kernel, a, b, options=stencil_opts)

    def test_basic83(self):
        """Neighborhood combined with ``standard_indexing`` for ``b`` and an
        explicit ``cval`` of 1.5."""
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b[1, 3]
            return cumul / (9.)

        a = np.arange(10. * 20.).reshape((10, 20))
        b = a.copy()
        stencil_opts = {
            'neighborhood': ((-3, 0), (-3, 0)),
            'standard_indexing': 'b',
            'cval': 1.5,
        }
        self.check(kernel, a, b, options=stencil_opts)

    def test_basic84(self):
        """Kernel that calls the module-level njit function
        ``addone_njit`` on a neighboring element."""
        def kernel(a):
            return a[0, 0] + addone_njit(a[0, 1])

        grid = np.arange(10. * 20.).reshape((10, 20))
        self.check(kernel, grid)

    def test_basic85(self):
        """Kernel that calls ``addone_pjit``, the module-level function
        compiled with njit(parallel=True), on a neighboring element."""
        def kernel(a):
            return a[0, 0] + addone_pjit(a[0, 1])

        grid = np.arange(10. * 20.).reshape((10, 20))
        self.check(kernel, grid)

    # njit/parfors fail correctly, but the error message isn't very informative
    def test_basic86(self):
        """Unknown stencil option ('bad'): either ValueError or TypingError
        is listed as the expected exception."""
        def kernel(a):
            return a[0, 0]

        grid = np.arange(10. * 20.).reshape((10, 20))
        unknown_opts = {'bad': 10}
        accepted_errors = [ValueError, TypingError]
        self.check(kernel, grid, options=unknown_opts,
                   expected_exception=accepted_errors)

    def test_basic87(self):
        """Kernel whose argument uses the reserved name ``__sentinel__``;
        no exception is expected."""
        def kernel(__sentinel__):
            return __sentinel__[0, 0]

        grid = np.arange(10. * 20.).reshape((10, 20))
        self.check(kernel, grid)

    def test_basic88(self):
        """Kernel with a parameter named ``out`` (a reserved word).

        An expected exception type is supplied for each of the pyStencil,
        stencil, parfor and njit modes.
        """
        def kernel(a, out):
            return out * a[0, 1]

        grid = np.arange(12.).reshape((3, 4))
        expected = self.exception_dict(
            pyStencil=ValueError,
            stencil=NumbaValueError,
            parfor=ValueError,
            njit=NumbaValueError,
        )
        # The scalar 1.0 is bound to the kernel's `out` parameter.
        self.check(kernel, grid, 1.0, options={}, expected_exception=expected)

    def test_basic89(self):
        """Kernel with several return statements spread over an
        if/elif/else chain."""
        def kernel(a):
            if a[0, 1] > 10:
                return 10.
            elif a[0, 3] < 8:
                return a[0, 0]
            else:
                return 7.

        grid = np.arange(10. * 20.).reshape((10, 20))
        self.check(kernel, grid)

    def test_basic90(self):
        """Neighborhood, ``standard_indexing`` for ``b`` and ``cval`` of 1.5
        together with a kernel that has two return statements."""
        def kernel(a, b):
            cumul = 0
            for i in range(-3, 1):
                for j in range(-3, 1):
                    cumul += a[i, j] + b[1, 3]
            res = cumul / (9.)
            if res > 200.0:
                return res + 1.0
            else:
                return res

        a = np.arange(10. * 20.).reshape((10, 20))
        b = a.copy()
        stencil_opts = {
            'neighborhood': ((-3, 0), (-3, 0)),
            'standard_indexing': 'b',
            'cval': 1.5,
        }
        self.check(kernel, a, b, options=stencil_opts)

    def test_basic91(self):
        """Regression test for issue #3454: a comparison between two integer
        constants inside the kernel was evaluated incorrectly."""
        def kernel(a):
            b = 0
            if(2 == 0):
                b = 2
            return a[0, 0] + b

        grid = np.arange(10. * 20.).reshape((10, 20))
        self.check(kernel, grid)

    def test_basic92(self):
        """Regression test for issue #3497: kernel with a bool return type
        (XOR over a 3x3 window of a boolean array), default options."""
        def kernel(a):
            return (a[-1, -1] ^ a[-1, 0] ^ a[-1, 1] ^
                    a[0, -1]  ^ a[0, 0]  ^ a[0, 1]  ^
                    a[1, -1]  ^ a[1, 0]  ^ a[1, 1])

        # Alternating 0/1 pattern, viewed as a 4x5 boolean grid.
        parity = np.arange(20) % 2
        A = np.array(parity).reshape(4, 5).astype(np.bool_)
        self.check(kernel, A)

    def test_basic93(self):
        """Regression test for issue #3497: kernel with a bool return type
        (XOR over a 3x3 window of a boolean array), run with ``cval=True``."""
        def kernel(a):
            return (a[-1, -1] ^ a[-1, 0] ^ a[-1, 1] ^
                    a[0, -1]  ^ a[0, 0]  ^ a[0, 1]  ^
                    a[1, -1]  ^ a[1, 0]  ^ a[1, 1])

        # Alternating 0/1 pattern, viewed as a 4x5 boolean grid.
        parity = np.arange(20) % 2
        A = np.array(parity).reshape(4, 5).astype(np.bool_)
        bool_cval_opts = {'cval': True}
        self.check(kernel, A, options=bool_cval_opts)

    def test_basic94(self):
        """Issue #3528: kernel that takes a 2D slice of its input (median of
        a 3x3 window) with the matching neighborhood given explicitly."""
        def kernel(a):
            return np.median(a[-1:2, -1:2])

        grid = np.arange(20, dtype=np.uint32).reshape((4, 5))
        stencil_opts = {'neighborhood': ((-1, 1), (-1, 1),)}
        self.check(kernel, grid, options=stencil_opts)

    @unittest.skip("not yet supported")
    def test_basic95(self):
        """2D slice kernel run without a ``neighborhood`` option, i.e. the
        neighborhood would have to be calculated.  Currently skipped."""
        def kernel(a):
            return np.median(a[-1:2, -3:4])

        grid = np.arange(20, dtype=np.uint32).reshape((4, 5))
        self.check(kernel, grid)

    def test_basic96(self):
        """Kernel taking a 1D slice of its input (median of offsets -1..1)
        with the matching neighborhood given explicitly."""
        def kernel(a):
            return np.median(a[-1:2])

        vec = np.arange(20, dtype=np.uint32)
        stencil_opts = {'neighborhood': ((-1, 1),)}
        self.check(kernel, vec, options=stencil_opts)

    @unittest.skip("not yet supported")
    def test_basic97(self):
        """Kernel mixing a slice on the first axis with a plain index on the
        second axis of a 2D array.  Currently skipped."""
        def kernel(a):
            return np.median(a[-1:2, 3])

        grid = np.arange(20, dtype=np.uint32).reshape((4, 5))
        self.check(kernel, grid)

    def test_basic98(self):
        """ Test issue #7286 where the cval is a np attr/string-based numerical
        constant.

        The kernel is run once per non-finite ``cval`` (nan, +/-inf spelled
        both via NumPy attributes and via ``float('inf')``).  Each value is
        wrapped in its own ``subTest`` so that a failure names the offending
        ``cval`` and the remaining values are still exercised.
        """
        for cval in (np.nan, np.inf, -np.inf, float('inf'), -float('inf')):
            with self.subTest(cval=cval):
                # A fresh kernel and input array per value, as before.
                def kernel(a):
                    return a[0, 0]
                a = np.arange(6.).reshape((2, 3))
                self.check(kernel, a,
                           options={'neighborhood': ((-1, 1), (-1, 1),),
                                    'cval': cval})


# Run this module's tests via unittest when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
