"""
Implement the random and np.random module functions.
"""


import math
import os
import random

import numpy as np

from llvmlite import ir

from numba.core.extending import overload, register_jitable
from numba.core.imputils import (Registry, impl_ret_untracked,
                                    impl_ret_new_ref)
from numba.core.typing import signature
from numba import _helperlib
from numba.core import types, utils, cgutils
from numba.np import arrayobj
from numba.core.overload_glue import glue_lowering
from numba.core.errors import NumbaTypeError


# True when running on Python 3.8 or later.  The gamma sampler in this module
# checks it to pick the matching CPython algorithm variant.
POST_PY38 = utils.PYVERSION >= (3, 8)


# Lowering registry named 'randomimpl'; `lower` is a shorthand for its
# `lower` attribute.
registry = Registry('randomimpl')
lower = registry.lower

# 32-bit and 64-bit LLVM integer types used by the code generation helpers.
int32_t = ir.IntType(32)
int64_t = ir.IntType(64)
def const_int(x):
    """Build a 32-bit LLVM integer constant holding *x*."""
    constant = ir.Constant(int32_t, x)
    return constant
# LLVM double-precision floating-point type.
double = ir.DoubleType()

# Number of 32-bit words in the generator's state array (the `mt[N]` field
# below), as a Python int and as a 32-bit LLVM constant.
N = 624
N_const = ir.Constant(int32_t, N)


# This is the same struct as rnd_state_t in _random.c.
# The field order matters: the accessor helpers in this module address the
# fields by position (0 = index, 1 = mt, 2 = has_gauss, 3 = gauss).
rnd_state_t = ir.LiteralStructType([
    # index
    int32_t,
    # mt[N]
    ir.ArrayType(int32_t, N),
    # has_gauss
    int32_t,
    # gauss
    double,
    # is_initialized
    int32_t,
    ])
# Pointer to a random state struct; this is the type exchanged with the
# external `numba_*` helper functions declared in this module.
rnd_state_ptr_t = ir.PointerType(rnd_state_t)

def get_state_ptr(context, builder, name):
    """
    Emit a call returning a pointer to a thread-local random state.

    *name* selects the state and must be one of "py", "np" or "internal";
    the call targets the external function
    ``numba_get_<name>_random_state``.
    If the state isn't initialized, it is lazily initialized with
    system entropy.
    """
    assert name in ('py', 'np', 'internal')
    getter_ty = ir.FunctionType(rnd_state_ptr_t, ())
    getter = cgutils.get_or_insert_function(
        builder.module, getter_ty, "numba_get_%s_random_state" % name)
    # Declaring the getter 'readnone' and 'nounwind' allows LLVM to hoist
    # the function call outside of loops.
    for attribute in ('readnone', 'nounwind'):
        getter.attributes.add(attribute)
    return builder.call(getter, ())

def get_py_state_ptr(context, builder):
    """
    Return a pointer to the thread-local "py" random state, i.e. the one
    used for the Python `random` module functions.
    """
    state_name = 'py'
    return get_state_ptr(context, builder, state_name)

def get_np_state_ptr(context, builder):
    """
    Return a pointer to the thread-local "np" random state, i.e. the one
    used for the `np.random` functions.
    """
    state_name = 'np'
    return get_state_ptr(context, builder, state_name)

def get_internal_state_ptr(context, builder):
    """
    Return a pointer to the thread-local "internal" random state.
    """
    state_name = 'internal'
    return get_state_ptr(context, builder, state_name)

# Accessors
def get_index_ptr(builder, state_ptr):
    """Return a pointer to the `index` field (position 0) of the state."""
    field_pos = 0
    return cgutils.gep_inbounds(builder, state_ptr, 0, field_pos)

def get_array_ptr(builder, state_ptr):
    """Return a pointer to the `mt` array field (position 1) of the state."""
    field_pos = 1
    return cgutils.gep_inbounds(builder, state_ptr, 0, field_pos)

def get_has_gauss_ptr(builder, state_ptr):
    """Return a pointer to the `has_gauss` field (position 2) of the state."""
    field_pos = 2
    return cgutils.gep_inbounds(builder, state_ptr, 0, field_pos)

def get_gauss_ptr(builder, state_ptr):
    """Return a pointer to the `gauss` field (position 3) of the state."""
    field_pos = 3
    return cgutils.gep_inbounds(builder, state_ptr, 0, field_pos)

def get_rnd_shuffle(builder):
    """
    Get the internal function to shuffle the MT state.

    Looks up (or declares) ``numba_rnd_shuffle`` in the module of the
    function being built; it takes a single state pointer and returns
    nothing.
    """
    module = builder.function.module
    shuffle_ty = ir.FunctionType(ir.VoidType(), (rnd_state_ptr_t,))
    shuffle_fn = cgutils.get_or_insert_function(module, shuffle_ty,
                                                "numba_rnd_shuffle")
    # Tag the state pointer argument with the "nocapture" attribute.
    shuffle_fn.args[0].add_attribute("nocapture")
    return shuffle_fn


def get_next_int32(context, builder, state_ptr):
    """
    Emit code producing the next int32 generated by the PRNG at *state_ptr*.

    When the stored index has reached N, the state array is regenerated
    through ``numba_rnd_shuffle`` and the index is reset to 0.  The word at
    the current index is then read, the index is advanced by one, and the
    tempered word is returned.
    """
    index_ptr = get_index_ptr(builder, state_ptr)
    current = builder.load(index_ptr)
    exhausted = builder.icmp_unsigned('>=', current, N_const)
    with cgutils.if_unlikely(builder, exhausted):
        shuffle_fn = get_rnd_shuffle(builder)
        builder.call(shuffle_fn, (state_ptr,))
        builder.store(const_int(0), index_ptr)
    # Reload the index: the branch above may have reset it.
    current = builder.load(index_ptr)
    words_ptr = get_array_ptr(builder, state_ptr)
    word_ptr = cgutils.gep_inbounds(builder, words_ptr, 0, current)
    word = builder.load(word_ptr)
    advanced = builder.add(current, const_int(1))
    builder.store(advanced, index_ptr)

    # Tempering
    shifted = builder.lshr(word, const_int(11))
    word = builder.xor(word, shifted)

    shifted = builder.shl(word, const_int(7))
    masked = builder.and_(shifted, const_int(0x9d2c5680))
    word = builder.xor(word, masked)

    shifted = builder.shl(word, const_int(15))
    masked = builder.and_(shifted, const_int(0xefc60000))
    word = builder.xor(word, masked)

    shifted = builder.lshr(word, const_int(18))
    return builder.xor(word, shifted)

def get_next_double(context, builder, state_ptr):
    """
    Emit code producing the next double generated by the PRNG at
    *state_ptr*.

    Two 32-bit draws are combined: the first shifted right by 5 bits, the
    second by 6 bits.
    """
    # a = rk_random(state) >> 5, b = rk_random(state) >> 6;
    first_draw = get_next_int32(context, builder, state_ptr)
    high_bits = builder.lshr(first_draw, const_int(5))
    second_draw = get_next_int32(context, builder, state_ptr)
    low_bits = builder.lshr(second_draw, const_int(6))

    # return (a * 67108864.0 + b) / 9007199254740992.0;
    high_fp = builder.uitofp(high_bits, double)
    low_fp = builder.uitofp(low_bits, double)
    scaled_high = builder.fmul(high_fp, ir.Constant(double, 67108864.0))
    numerator = builder.fadd(low_fp, scaled_high)
    return builder.fdiv(numerator, ir.Constant(double, 9007199254740992.0))

def get_next_int(context, builder, state_ptr, nbits, is_numpy):
    """
    Get the next integer with width *nbits*.

    *nbits* is an LLVM integer value; the code below assumes it is at most
    64 (see the XXX note).  *is_numpy* is a Python-level flag evaluated at
    code generation time: it selects which bits of a 32-bit draw are kept
    and the order in which the high and low words are drawn.  The result
    is returned as a 64-bit LLVM integer value.
    """
    c32 = ir.Constant(nbits.type, 32)
    def get_shifted_int(nbits):
        # Draw one 32-bit word and keep only *nbits* bits of it.
        shift = builder.sub(c32, nbits)
        y = get_next_int32(context, builder, state_ptr)
        if is_numpy:
            # Use the last N bits, to match np.random
            mask = builder.not_(ir.Constant(y.type, 0))
            mask = builder.lshr(mask, builder.zext(shift, y.type))
            return builder.and_(y, mask)
        else:
            # Use the first N bits, to match CPython random
            return builder.lshr(y, builder.zext(shift, y.type))

    # Stack slot receiving the result from whichever branch runs.
    ret = cgutils.alloca_once_value(builder, ir.Constant(int64_t, 0))

    is_32b = builder.icmp_unsigned('<=', nbits, c32)
    with builder.if_else(is_32b) as (ifsmall, iflarge):
        with ifsmall:
            # A single 32-bit draw is enough.
            low = get_shifted_int(nbits)
            builder.store(builder.zext(low, int64_t), ret)
        with iflarge:
            # XXX This assumes nbits <= 64
            # Two draws are needed; which one is drawn first depends on
            # *is_numpy*, so exactly one of the two `high = ...` statements
            # below is executed at code generation time.
            if is_numpy:
                # Get the high bits first to match np.random
                high = get_shifted_int(builder.sub(nbits, c32))
            low = get_next_int32(context, builder, state_ptr)
            if not is_numpy:
                # Get the high bits second to match CPython random
                high = get_shifted_int(builder.sub(nbits, c32))
            # total = low + (high << 32)
            total = builder.add(
                builder.zext(low, int64_t),
                builder.shl(builder.zext(high, int64_t), ir.Constant(int64_t, 32)))
            builder.store(total, ret)

    return builder.load(ret)


def _fill_defaults(context, builder, sig, args, defaults):
    """
    Complete *args* with trailing entries of *defaults*.

    The signature is assumed to be homogeneous (same type for the result
    and all arguments).  Every position of *defaults* not covered by
    *args* is appended as an LLVM constant of the return type.  Returns
    the rebuilt homogeneous signature and the completed argument tuple.
    """
    value_ty = sig.return_type
    ll_value_ty = context.get_data_type(value_ty)
    missing = defaults[len(args):]
    filled = tuple(args) + tuple(ir.Constant(ll_value_ty, d)
                                 for d in missing)
    # One entry for the return type plus one per argument.
    homogeneous = [value_ty] * (len(filled) + 1)
    return signature(*homogeneous), filled


@glue_lowering("random.seed", types.uint32)
def seed_impl(context, builder, sig, args):
    """Lower ``random.seed`` for a uint32 seed, acting on the "py" state."""
    py_state = get_state_ptr(context, builder, "py")
    none_val = _seed_impl(context, builder, sig, args, py_state)
    return impl_ret_untracked(context, builder, sig.return_type, none_val)

@glue_lowering("np.random.seed", types.uint32)
def seed_impl(context, builder, sig, args):
    """Lower ``np.random.seed`` for a uint32 seed, acting on the "np" state."""
    np_state = get_state_ptr(context, builder, "np")
    none_val = _seed_impl(context, builder, sig, args, np_state)
    return impl_ret_untracked(context, builder, sig.return_type, none_val)

def _seed_impl(context, builder, sig, args, state_ptr):
    """
    Emit a call to the external ``numba_rnd_init`` with *state_ptr* and the
    single seed value found in *args*; return a constant None value.
    """
    (seed_value,) = args
    init_ty = ir.FunctionType(ir.VoidType(), (rnd_state_ptr_t, int32_t))
    init_fn = cgutils.get_or_insert_function(builder.function.module,
                                             init_ty, "numba_rnd_init")
    builder.call(init_fn, (state_ptr, seed_value))
    return context.get_constant(types.none, None)

@glue_lowering("random.random")
def random_impl(context, builder, sig, args):
    """Lower ``random.random()``: next double from the "py" state."""
    py_state = get_state_ptr(context, builder, "py")
    value = get_next_double(context, builder, py_state)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.random")
@glue_lowering("np.random.random_sample")
@glue_lowering("np.random.sample")
@glue_lowering("np.random.ranf")
def random_impl(context, builder, sig, args):
    """
    Lower the argument-less ``np.random`` uniform samplers: next double
    from the "np" state.
    """
    np_state = get_state_ptr(context, builder, "np")
    value = get_next_double(context, builder, np_state)
    return impl_ret_untracked(context, builder, sig.return_type, value)


@glue_lowering("random.gauss", types.Float, types.Float)
@glue_lowering("random.normalvariate", types.Float, types.Float)
def gauss_impl(context, builder, sig, args):
    """
    Lower ``random.gauss`` and ``random.normalvariate`` using the "py"
    state.
    """
    sample = _gauss_impl(context, builder, sig, args, "py")
    return impl_ret_untracked(context, builder, sig.return_type, sample)


@glue_lowering("np.random.standard_normal")
@glue_lowering("np.random.normal")
@glue_lowering("np.random.normal", types.Float)
@glue_lowering("np.random.normal", types.Float, types.Float)
def np_gauss_impl(context, builder, sig, args):
    """
    Lower ``np.random.standard_normal`` and ``np.random.normal`` using the
    "np" state.  Omitted arguments default to loc 0.0 and scale 1.0.
    """
    loc_scale_defaults = (0.0, 1.0)
    full_sig, full_args = _fill_defaults(context, builder, sig, args,
                                         loc_scale_defaults)
    sample = _gauss_impl(context, builder, full_sig, full_args, "np")
    return impl_ret_untracked(context, builder, full_sig.return_type, sample)


def _gauss_pair_impl(_random):
    """
    Build a function computing two normal deviates at once, drawing its
    uniform samples from *_random*.
    """
    def compute_gauss_pair():
        """
        Compute a pair of numbers on the normal distribution.
        """
        while True:
            u = 2.0 * _random() - 1.0
            v = 2.0 * _random() - 1.0
            norm2 = u * u + v * v
            if not (norm2 < 1.0 and norm2 != 0.0):
                # Reject points outside the unit disc, and the origin.
                continue

            # Box-Muller transform
            scale = math.sqrt(-2.0 * math.log(norm2) / norm2)
            return scale * u, scale * v
    return compute_gauss_pair

def _gauss_impl(context, builder, sig, args, state):
    """
    Emit code returning ``mu + sigma * g``, where ``g`` is a normal deviate
    obtained through the random state named *state* ("py" or "np").

    Deviates are computed in pairs: one is returned, the other is kept in
    the state's `gauss` field (flagged by `has_gauss`) for the next call.
    """
    # The type for all computations (either float or double)
    float_ty = sig.return_type
    ll_float_ty = context.get_data_type(float_ty)

    state_ptr = get_state_ptr(context, builder, state)
    uniform_by_state = {"py": random.random,
                        "np": np.random.random}
    _random = uniform_by_state[state]

    result_slot = cgutils.alloca_once(builder, ll_float_ty, name="result")

    cached_ptr = get_gauss_ptr(builder, state_ptr)
    flag_ptr = get_has_gauss_ptr(builder, state_ptr)
    flag_is_set = cgutils.is_true(builder, builder.load(flag_ptr))
    with builder.if_else(flag_is_set) as (use_cached, draw_pair):
        with use_cached:
            # A deviate is already stored: use it and clear the flag
            cached = builder.load(cached_ptr)
            builder.store(cached, result_slot)
            builder.store(const_int(0), flag_ptr)
        with draw_pair:
            # Nothing stored: compute a pair of numbers using the
            # Box-Muller transform; keep one and return the other
            pair_fn = _gauss_pair_impl(_random)
            pair_sig = signature(types.UniTuple(float_ty, 2))
            pair = context.compile_internal(builder, pair_fn, pair_sig, ())

            kept, returned = cgutils.unpack_tuple(builder, pair, 2)
            builder.store(kept, cached_ptr)
            builder.store(returned, result_slot)
            builder.store(const_int(1), flag_ptr)

    mu, sigma = args
    scaled = builder.fmul(sigma, builder.load(result_slot))
    return builder.fadd(mu, scaled)

@glue_lowering("random.getrandbits", types.Integer)
def getrandbits_impl(context, builder, sig, args):
    """
    Lower ``random.getrandbits(nbits)`` using the "py" state.

    The generated code raises OverflowError unless 1 <= nbits <= 64
    (compared as unsigned).
    """
    (nbits,) = args
    above_max = builder.icmp_unsigned(">=", nbits, const_int(65))
    is_zero = builder.icmp_unsigned("==", nbits, const_int(0))
    out_of_range = builder.or_(above_max, is_zero)
    with cgutils.if_unlikely(builder, out_of_range):
        msg = "getrandbits() limited to 64 bits"
        context.call_conv.return_user_exc(builder, OverflowError, (msg,))
    py_state = get_state_ptr(context, builder, "py")
    bits = get_next_int(context, builder, py_state, nbits, False)
    return impl_ret_untracked(context, builder, sig.return_type, bits)


def _randrange_impl(context, builder, start, stop, step, state):
    """
    Emit code drawing a random integer from ``range(start, stop, step)``.

    *start*, *stop* and *step* are LLVM integer values sharing the type of
    *stop*; *state* names the random state to draw from (e.g. "py" or
    "np") and also selects the bit-extraction convention.

    The generated code raises ValueError when *step* is zero or when the
    range is empty.  Otherwise it computes the range length ``n``, draws
    ``r`` in ``[0, n)`` by rejection sampling and returns the LLVM value
    ``start + r * step``.
    """
    state_ptr = get_state_ptr(context, builder, state)
    ty = stop.type
    zero = ir.Constant(ty, 0)
    one = ir.Constant(ty, 1)

    with cgutils.if_unlikely(builder, builder.icmp_signed('==', step, zero)):
        # A zero step does not describe a range.  Without this check the
        # code below would silently return *start*; CPython raises instead.
        msg = "zero step for randrange()"
        context.call_conv.return_user_exc(builder, ValueError, (msg,))

    nptr = cgutils.alloca_once(builder, ty, name="n")
    # n = stop - start
    builder.store(builder.sub(stop, start), nptr)

    with builder.if_then(builder.icmp_signed('<', step, zero)):
        # n = (n + step + 1) // step
        w = builder.add(builder.add(builder.load(nptr), step), one)
        n = builder.sdiv(w, step)
        builder.store(n, nptr)
    with builder.if_then(builder.icmp_signed('>', step, one)):
        # n = (n + step - 1) // step
        w = builder.sub(builder.add(builder.load(nptr), step), one)
        n = builder.sdiv(w, step)
        builder.store(n, nptr)

    n = builder.load(nptr)
    with cgutils.if_unlikely(builder, builder.icmp_signed('<=', n, zero)):
        # n <= 0
        msg = "empty range for randrange()"
        context.call_conv.return_user_exc(builder, ValueError, (msg,))

    # Count-leading-zeros intrinsic for the integer type in use.
    fnty = ir.FunctionType(ty, [ty, cgutils.true_bit.type])
    fn = cgutils.get_or_insert_function(builder.function.module, fnty,
                                        "llvm.ctlz.%s" % ty)
    # Since the upper bound is exclusive, we need to subtract one before
    # calculating the number of bits. This leads to a special case when
    # n == 1; there's only one possible result, so we don't need bits from
    # the PRNG. This case is handled separately towards the end of this
    # function. CPython's implementation is simpler and just runs another
    # iteration of the while loop when the resulting number is too large
    # instead of subtracting one, to avoid needing to handle a special
    # case. Thus, we only perform this subtraction for the NumPy case.
    nm1 = builder.sub(n, one) if state == "np" else n
    nbits = builder.trunc(builder.call(fn, [nm1, cgutils.true_bit]), int32_t)
    # nbits = bit width of the type minus the number of leading zeros
    nbits = builder.sub(ir.Constant(int32_t, ty.width), nbits)

    rptr = cgutils.alloca_once(builder, ty, name="r")

    def get_num():
        # Rejection loop: draw nbits-wide integers until one is < n.
        bbwhile = builder.append_basic_block("while")
        bbend = builder.append_basic_block("while.end")
        builder.branch(bbwhile)

        builder.position_at_end(bbwhile)
        r = get_next_int(context, builder, state_ptr, nbits, state == "np")
        r = builder.trunc(r, ty)
        too_large = builder.icmp_signed('>=', r, n)
        builder.cbranch(too_large, bbwhile, bbend)

        builder.position_at_end(bbend)
        builder.store(r, rptr)

    if state == "np":
        # Handle n == 1 case, per previous comment.
        with builder.if_else(builder.icmp_signed('==', n, one)) as (is_one, is_not_one):
            with is_one:
                builder.store(zero, rptr)
            with is_not_one:
                get_num()
    else:
        get_num()

    return builder.add(start, builder.mul(builder.load(rptr), step))


@glue_lowering("random.randrange", types.Integer)
def randrange_impl_1(context, builder, sig, args):
    """Lower ``random.randrange(stop)``: start is 0 and step is 1."""
    (stop,) = args
    int_ty = stop.type
    start, step = ir.Constant(int_ty, 0), ir.Constant(int_ty, 1)
    value = _randrange_impl(context, builder, start, stop, step, "py")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("random.randrange", types.Integer, types.Integer)
def randrange_impl_2(context, builder, sig, args):
    """Lower ``random.randrange(start, stop)``: step is 1."""
    start, stop = args
    unit_step = ir.Constant(start.type, 1)
    value = _randrange_impl(context, builder, start, stop, unit_step, "py")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("random.randrange", types.Integer,
               types.Integer, types.Integer)
def randrange_impl_3(context, builder, sig, args):
    """Lower ``random.randrange(start, stop, step)`` using the "py" state."""
    value = _randrange_impl(context, builder, *args, "py")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("random.randint", types.Integer, types.Integer)
def randint_impl_1(context, builder, sig, args):
    """
    Lower ``random.randint(a, b)`` as a randrange over ``[a, b + 1)``
    with step 1, using the "py" state.
    """
    start, last = args
    unit = ir.Constant(start.type, 1)
    # The upper bound is inclusive here, so shift it by one.
    stop = builder.add(last, unit)
    value = _randrange_impl(context, builder, start, stop, unit, "py")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.randint", types.Integer)
def randint_impl_2(context, builder, sig, args):
    """
    Lower ``np.random.randint(high)``: start is 0, step is 1 and the upper
    bound is exclusive; uses the "np" state.
    """
    (stop,) = args
    int_ty = stop.type
    start, step = ir.Constant(int_ty, 0), ir.Constant(int_ty, 1)
    value = _randrange_impl(context, builder, start, stop, step, "np")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.randint", types.Integer, types.Integer)
def randrange_impl_2(context, builder, sig, args):
    """
    Lower ``np.random.randint(low, high)``: step is 1 and the upper bound
    is exclusive; uses the "np" state.
    """
    start, stop = args
    unit_step = ir.Constant(start.type, 1)
    value = _randrange_impl(context, builder, start, stop, unit_step, "np")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("random.uniform", types.Float, types.Float)
def uniform_impl(context, builder, sig, args):
    """Lower ``random.uniform(a, b)`` using the "py" state."""
    res = _uniform_impl(context, builder, sig, args, "py")
    return impl_ret_untracked(context, builder, sig.return_type, res)

@glue_lowering("np.random.uniform", types.Float, types.Float)
def uniform_impl(context, builder, sig, args):
    """Lower ``np.random.uniform(low, high)`` using the "np" state."""
    res = _uniform_impl(context, builder, sig, args, "np")
    return impl_ret_untracked(context, builder, sig.return_type, res)

def _uniform_impl(context, builder, sig, args, state):
    """
    Emit code returning ``a + (b - a) * r``, where ``(a, b)`` are the two
    values in *args* and ``r`` is the next double drawn from the random
    state named *state*.
    """
    state_ptr = get_state_ptr(context, builder, state)
    a, b = args
    width = builder.fsub(b, a)
    r = get_next_double(context, builder, state_ptr)
    return builder.fadd(a, builder.fmul(width, r))

# The shared helper used to be defined under the name `uniform_impl`,
# rebinding the two wrappers above (which only worked because they looked
# the name up at call time).  Keep that final module-level binding for
# backward compatibility.
uniform_impl = _uniform_impl

@glue_lowering("random.triangular", types.Float, types.Float)
def triangular_impl_2(context, builder, sig, args):
    """
    Lower ``random.triangular(low, high)`` using the "py" state.

    A uniform double is drawn at code generation level and handed to a
    compiled helper that maps it onto the triangular distribution.
    """
    low, high = args
    float_ty = sig.return_type
    py_state = get_state_ptr(context, builder, "py")
    randval = get_next_double(context, builder, py_state)

    def triangular_impl_2(randval, low, high):
        # The normalized mode position is fixed at 0.5 in this variant.
        frac = randval
        half = 0.5
        if frac > half:
            frac = 1.0 - frac
            low, high = high, low
        span = high - low
        return low + span * math.sqrt(frac * half)

    # Return type followed by the three argument types.
    inner_sig = signature(float_ty, float_ty, float_ty, float_ty)
    res = context.compile_internal(builder, triangular_impl_2, inner_sig,
                                   (randval, low, high))
    return impl_ret_untracked(context, builder, sig.return_type, res)

@glue_lowering("random.triangular", types.Float, types.Float, types.Float)
def triangular_impl_3(context, builder, sig, args):
    """
    Lower ``random.triangular(low, high, mode)`` using the "py" state.
    """
    # Argument order for the `random` module: (low, high, mode)
    low, high, mode = args
    value = _triangular_impl_3(context, builder, sig, low, high, mode, "py")
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.triangular", types.Float,
           types.Float, types.Float)
def triangular_impl_3(context, builder, sig, args):
    """
    Lower the three-argument ``np.random.triangular`` using the "np" state.
    """
    # Argument order here is (low, mode, high); the shared helper takes
    # (low, high, mode).
    low, mode, high = args
    value = _triangular_impl_3(context, builder, sig, low, high, mode, "np")
    return impl_ret_untracked(context, builder, sig.return_type, value)

def _triangular_impl_3(context, builder, sig, low, high, mode, state):
    """
    Emit code sampling the triangular distribution on [low, high] with the
    given *mode*, drawing one uniform double from the state named *state*.
    """
    float_ty = sig.return_type
    state_ptr = get_state_ptr(context, builder, state)
    randval = get_next_double(context, builder, state_ptr)

    def triangular_impl_3(randval, low, high, mode):
        # Degenerate interval: the only possible value is `low`.
        if high == low:
            return low
        frac = randval
        mode_pos = (mode - low) / (high - low)
        if frac > mode_pos:
            frac = 1.0 - frac
            mode_pos = 1.0 - mode_pos
            low, high = high, low
        span = high - low
        return low + span * math.sqrt(frac * mode_pos)

    # Return type followed by the four argument types.
    inner_sig = signature(float_ty, float_ty, float_ty, float_ty, float_ty)
    return context.compile_internal(builder, triangular_impl_3, inner_sig,
                                    (randval, low, high, mode))


@glue_lowering("random.gammavariate", types.Float, types.Float)
def gammavariate_impl(context, builder, sig, args):
    """
    Lower ``random.gammavariate(alpha, beta)``; uniform samples come from
    ``random.random``.
    """
    uniform = random.random
    value = _gammavariate_impl(context, builder, sig, args, uniform)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.standard_gamma", types.Float)
@glue_lowering("np.random.gamma", types.Float)
@glue_lowering("np.random.gamma", types.Float, types.Float)
def gammavariate_impl(context, builder, sig, args):
    """
    Lower ``np.random.standard_gamma`` and ``np.random.gamma``; uniform
    samples come from ``np.random.random``.  A missing second argument
    defaults to 1.0.
    """
    # The first argument is always supplied, so its default is never used.
    shape_scale_defaults = (None, 1.0)
    full_sig, full_args = _fill_defaults(context, builder, sig, args,
                                         shape_scale_defaults)
    value = _gammavariate_impl(context, builder, full_sig, full_args,
                               np.random.random)
    return impl_ret_untracked(context, builder, full_sig.return_type, value)

def _gammavariate_impl(context, builder, sig, args, _random):
    """
    Compile and call the gamma sampler shared by the `random` and
    `np.random` lowerings.  *_random* is the uniform generator called by
    the compiled code.
    """
    exp = math.exp
    log = math.log
    sqrt = math.sqrt
    e = math.e

    log4 = log(4.0)
    sg_magic = 1.0 + log(4.5)

    def gammavariate_impl(alpha, beta):
        """Gamma distribution.  Taken from CPython.
        """
        # alpha > 0, beta > 0, mean is alpha*beta, variance is alpha*beta**2

        # Warning: a few older sources define the gamma distribution in terms
        # of alpha > -1.0
        if alpha <= 0.0 or beta <= 0.0:
            raise ValueError('gammavariate: alpha and beta must be > 0.0')

        if alpha > 1.0:
            # Uses R.C.H. Cheng, "The generation of Gamma
            # variables with non-integral shape parameters",
            # Applied Statistics, (1977), 26, No. 1, p71-74
            root = sqrt(2.0 * alpha - 1.0)
            offset = alpha - log4
            slope = alpha + root

            while True:
                u1 = _random()
                if not 1e-7 < u1 < .9999999:
                    continue
                u2 = 1.0 - _random()
                v = log(u1 / (1.0 - u1)) / root
                x = alpha * exp(v)
                z = u1 * u1 * u2
                r = offset + slope * v - x
                if r + sg_magic - 4.5 * z >= 0.0 or r >= log(z):
                    return x * beta

        elif alpha == 1.0:
            # expovariate(1)

            if POST_PY38:
                # Adjust due to cpython
                # commit 63d152232e1742660f481c04a811f824b91f6790
                return -log(1.0 - _random()) * beta
            else:
                u = _random()
                while u <= 1e-7:
                    u = _random()
                return -log(u) * beta

        else:   # alpha is between 0 and 1 (exclusive)
            # Uses ALGORITHM GS of Statistical Computing - Kennedy & Gentle
            while True:
                u = _random()
                b = (e + alpha) / e
                p = b * u
                if p <= 1.0:
                    x = p ** (1.0 / alpha)
                else:
                    x = -log((b - p) / alpha)
                u1 = _random()
                if p > 1.0:
                    if u1 <= x ** (alpha - 1.0):
                        break
                elif u1 <= exp(-x):
                    break
            return x * beta

    return context.compile_internal(builder, gammavariate_impl,
                                    sig, args)


@glue_lowering("random.betavariate", types.Float, types.Float)
def betavariate_impl(context, builder, sig, args):
    """
    Lower ``random.betavariate(alpha, beta)``, built on
    ``random.gammavariate``.
    """
    gamma_fn = random.gammavariate
    value = _betavariate_impl(context, builder, sig, args, gamma_fn)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.beta",
           types.Float, types.Float)
def betavariate_impl(context, builder, sig, args):
    """
    Lower ``np.random.beta(a, b)``, built on ``np.random.gamma``.
    """
    gamma_fn = np.random.gamma
    value = _betavariate_impl(context, builder, sig, args, gamma_fn)
    return impl_ret_untracked(context, builder, sig.return_type, value)

def _betavariate_impl(context, builder, sig, args, gamma):
    """
    Compile and call the beta sampler; *gamma* is the two-argument gamma
    sampler called by the compiled code.
    """

    def betavariate_impl(alpha, beta):
        """Beta distribution.  Taken from CPython.
        """
        # This version due to Janne Sinkkonen, and matches all the std
        # texts (e.g., Knuth Vol 2 Ed 3 pg 134 "the beta distribution").
        first = gamma(alpha, 1.)
        if first == 0.0:
            # Avoid drawing the second deviate when the ratio is zero anyway.
            return 0.0
        return first / (first + gamma(beta, 1.))

    return context.compile_internal(builder, betavariate_impl,
                                    sig, args)


@glue_lowering("random.expovariate", types.Float)
def expovariate_impl(context, builder, sig, args):
    """
    Lower ``random.expovariate(lambd)``; uniform samples come from
    ``random.random``.
    """
    uniform = random.random
    log = math.log

    def expovariate_impl(lambd):
        """Exponential distribution.  Taken from CPython.
        """
        # lambd: rate lambd = 1/mean
        # ('lambda' is a Python reserved word)

        # we use 1-random() instead of random() to preclude the
        # possibility of taking the log of zero.
        return -log(1.0 - uniform()) / lambd

    value = context.compile_internal(builder, expovariate_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.exponential", types.Float)
def exponential_impl(context, builder, sig, args):
    """
    Lower ``np.random.exponential(scale)``; uniform samples come from
    ``np.random.random``.
    """
    uniform = np.random.random
    log = math.log

    def exponential_impl(scale):
        # 1 - random() precludes taking the log of zero.
        return -log(1.0 - uniform()) * scale

    value = context.compile_internal(builder, exponential_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.standard_exponential")
@glue_lowering("np.random.exponential")
def exponential_impl(context, builder, sig, args):
    """
    Lower the argument-less ``np.random.standard_exponential`` and
    ``np.random.exponential``; the result is not scaled.
    """
    uniform = np.random.random
    log = math.log

    def exponential_impl():
        # 1 - random() precludes taking the log of zero.
        return -log(1.0 - uniform())

    value = context.compile_internal(builder, exponential_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.lognormal")
@glue_lowering("np.random.lognormal", types.Float)
@glue_lowering("np.random.lognormal", types.Float, types.Float)
def np_lognormal_impl(context, builder, sig, args):
    """
    Lower ``np.random.lognormal``, built on ``np.random.normal``.  Omitted
    arguments default to 0.0 and 1.0 respectively.
    """
    defaults = (0.0, 1.0)
    full_sig, full_args = _fill_defaults(context, builder, sig, args,
                                         defaults)
    value = _lognormvariate_impl(context, builder, full_sig, full_args,
                                 np.random.normal)
    return impl_ret_untracked(context, builder, full_sig.return_type, value)

@glue_lowering("random.lognormvariate",
           types.Float, types.Float)
def lognormvariate_impl(context, builder, sig, args):
    """
    Lower ``random.lognormvariate(mu, sigma)``, built on ``random.gauss``.
    """
    normal_fn = random.gauss
    value = _lognormvariate_impl(context, builder, sig, args, normal_fn)
    return impl_ret_untracked(context, builder, sig.return_type, value)

def _lognormvariate_impl(context, builder, sig, args, _gauss):
    """
    Compile and call the log-normal sampler: the exponential of a normal
    deviate obtained from *_gauss*.
    """
    exp = math.exp

    def lognormvariate_impl(mu, sigma):
        normal_deviate = _gauss(mu, sigma)
        return exp(normal_deviate)

    return context.compile_internal(builder, lognormvariate_impl,
                                    sig, args)


@glue_lowering("random.paretovariate", types.Float)
def paretovariate_impl(context, builder, sig, args):
    """
    Lower ``random.paretovariate(alpha)``; uniform samples come from
    ``random.random``.
    """
    uniform = random.random

    def paretovariate_impl(alpha):
        """Pareto distribution.  Taken from CPython."""
        # Jain, pg. 495
        complement = 1.0 - uniform()
        return 1.0 / complement ** (1.0 / alpha)

    value = context.compile_internal(builder, paretovariate_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.pareto", types.Float)
def pareto_impl(context, builder, sig, args):
    """
    Lower ``np.random.pareto(a)``; uniform samples come from
    ``np.random.random``.
    """
    uniform = np.random.random

    def pareto_impl(alpha):
        # Same as paretovariate() - 1.
        complement = 1.0 - uniform()
        return 1.0 / complement ** (1.0 / alpha) - 1

    value = context.compile_internal(builder, pareto_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("random.weibullvariate",
           types.Float, types.Float)
def weibullvariate_impl(context, builder, sig, args):
    """
    Lower ``random.weibullvariate(alpha, beta)``; uniform samples come
    from ``random.random``.
    """
    uniform = random.random
    log = math.log

    def weibullvariate_impl(alpha, beta):
        """Weibull distribution.  Taken from CPython."""
        # Jain, pg. 499; bug fix courtesy Bill Arms
        complement = 1.0 - uniform()
        return alpha * (-log(complement)) ** (1.0 / beta)

    value = context.compile_internal(builder, weibullvariate_impl,
                                     sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.weibull", types.Float)
def weibull_impl(context, builder, sig, args):
    """
    Lower ``np.random.weibull(a)``; uniform samples come from
    ``np.random.random``.
    """
    uniform = np.random.random
    log = math.log

    def weibull_impl(beta):
        # Same as weibullvariate(1.0, beta)
        complement = 1.0 - uniform()
        return (-log(complement)) ** (1.0 / beta)

    value = context.compile_internal(builder, weibull_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("random.vonmisesvariate",
           types.Float, types.Float)
def vonmisesvariate_impl(context, builder, sig, args):
    """
    Lower ``random.vonmisesvariate(mu, kappa)``; uniform samples come from
    ``random.random``.
    """
    uniform = random.random
    value = _vonmisesvariate_impl(context, builder, sig, args, uniform)
    return impl_ret_untracked(context, builder, sig.return_type, value)

@glue_lowering("np.random.vonmises",
           types.Float, types.Float)
def vonmisesvariate_impl(context, builder, sig, args):
    """
    Lower ``np.random.vonmises(mu, kappa)``; uniform samples come from
    ``np.random.random``.
    """
    uniform = np.random.random
    value = _vonmisesvariate_impl(context, builder, sig, args, uniform)
    return impl_ret_untracked(context, builder, sig.return_type, value)

def _vonmisesvariate_impl(context, builder, sig, args, _random):
    """
    Compile and call the von Mises sampler shared by the `random` and
    `np.random` lowerings.  *_random* is the uniform generator called by
    the compiled code.
    """
    exp = math.exp
    sqrt = math.sqrt
    cos = math.cos
    acos = math.acos
    pi = math.pi
    two_pi = 2.0 * pi

    def vonmisesvariate_impl(mu, kappa):
        """Circular data distribution.  Taken from CPython.
        Note the algorithm in Python 2.6 and Numpy is different:
        http://bugs.python.org/issue17141
        """
        # mu:    mean angle (in radians between 0 and 2*pi)
        # kappa: concentration parameter kappa (>= 0)
        # if kappa = 0 generate uniform random angle

        # Based upon an algorithm published in: Fisher, N.I.,
        # "Statistical Analysis of Circular Data", Cambridge
        # University Press, 1993.

        # Thanks to Magnus Kessler for a correction to the
        # implementation of step 4.
        if kappa <= 1e-6:
            return two_pi * _random()

        s = 0.5 / kappa
        r = s + sqrt(1.0 + s * s)

        while True:
            u1 = _random()
            z = cos(pi * u1)

            d = z / (r + z)
            u2 = _random()
            if u2 < 1.0 - d * d or u2 <= (1.0 - d) * exp(d):
                break

        q = 1.0 / r
        f = (q + z) / (1.0 + q * z)
        u3 = _random()
        # The third draw picks the side of *mu* the angle falls on.
        angle = acos(f)
        if u3 > 0.5:
            return (mu + angle) % two_pi
        return (mu - angle) % two_pi

    return context.compile_internal(builder, vonmisesvariate_impl,
                                    sig, args)


@glue_lowering("np.random.binomial", types.Integer, types.Float)
def binomial_impl(context, builder, sig, args):
    """
    Lower ``np.random.binomial(n, p)``; uniform samples come from
    ``np.random.random``.

    The compiled code raises ValueError when ``n`` is negative or ``p`` is
    outside [0, 1].
    """
    _random = np.random.random

    def binomial_impl(n, p):
        """
        Binomial distribution.  Numpy's variant of the BINV algorithm
        is used.
        (Numpy uses BTPE for n*p >= 30, though)
        """
        if n < 0:
            raise ValueError("binomial(): n < 0")
        if not (0.0 <= p <= 1.0):
            raise ValueError("binomial(): p outside of [0, 1]")
        if p == 0.0:
            return 0
        if p == 1.0:
            return n

        # Work with the smaller of p and 1 - p; results are mirrored back
        # (n - X) when accumulating.
        flipped = p > 0.5
        if flipped:
            p = 1.0 - p
        q = 1.0 - p

        niters = 1
        qn = q ** n
        while qn <= 1e-308:
            # Underflow => split into several iterations
            # Note this is much slower than Numpy's BTPE
            niters <<= 2
            n >>= 2
            qn = q ** n
            assert n > 0

        mean = n * p
        bound = min(n, mean + 10.0 * math.sqrt(mean * q + 1))

        total = 0
        while niters > 0:
            X = 0
            U = _random()
            px = qn
            # If X runs past `bound` without a hit, niters is left
            # unchanged and the outer loop retries with a fresh draw.
            while X <= bound:
                if U <= px:
                    total += n - X if flipped else X
                    niters -= 1
                    break
                U -= px
                X += 1
                px = ((n - X + 1) * p * px) / (X * q)

        return total

    res = context.compile_internal(builder, binomial_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@glue_lowering("np.random.chisquare", types.Float)
def chisquare_impl(context, builder, sig, args):
    """
    Lower ``np.random.chisquare(df)`` as twice a standard gamma deviate of
    shape ``df / 2``.
    """

    def chisquare_impl(df):
        half_df = df / 2.0
        return 2.0 * np.random.standard_gamma(half_df)

    value = context.compile_internal(builder, chisquare_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, value)


@glue_lowering("np.random.f", types.Float, types.Float)
def f_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.f(num, denom) call.
    """

    def f_impl(num, denom):
        # The chi-square draw for *num* is taken first, then the one
        # for *denom*; the result is their cross-scaled ratio.
        upper = np.random.chisquare(num) * denom
        lower = np.random.chisquare(denom) * num
        return upper / lower

    compiled = context.compile_internal(builder, f_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.geometric", types.Float)
def geometric_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.geometric(p) call.
    """
    intty = sig.return_type
    _random = np.random.random

    def geometric_impl(p):
        """
        Numpy's algorithm.

        Raises ValueError if *p* is outside (0, 1].
        """
        if p <= 0.0 or p > 1.0:
            raise ValueError("geometric(): p outside of (0, 1]")
        q = 1.0 - p
        if p >= 0.333333333333333333333333:
            # Walk the cumulative probabilities until they reach the
            # uniform draw, counting the steps taken.
            count = intty(1)
            term = p
            cumulative = p
            draw = _random()
            while draw > cumulative:
                term *= q
                cumulative += term
                count += 1
            return count
        # Small p: closed-form expression on a single uniform draw.
        return math.ceil(math.log(1.0 - _random()) / math.log(q))

    compiled = context.compile_internal(builder, geometric_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.gumbel", types.Float, types.Float)
def gumbel_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.gumbel(loc, scale) call.
    """

    def gumbel_impl(loc, scale):
        # 1 - U lies in (0, 1], which keeps the inner log away from 0.
        complement = 1.0 - np.random.random()
        inner = -math.log(complement)
        return loc - scale * math.log(inner)

    compiled = context.compile_internal(builder, gumbel_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.hypergeometric", types.Integer,
           types.Integer, types.Integer)
def hypergeometric_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.hypergeometric(ngood, nbad, nsamples) call.
    """
    _random = np.random.random

    def hypergeometric_impl(ngood, nbad, nsamples):
        """Numpy's algorithm for hypergeometric()."""
        leftover = nbad + ngood - nsamples
        smaller = float(min(nbad, ngood))

        # Count down one draw at a time until either the smaller group
        # or the number of draws is exhausted.
        remaining = smaller
        draws_left = nsamples
        while remaining > 0.0 and draws_left > 0:
            remaining -= math.floor(
                _random() + remaining / (leftover + draws_left))
            draws_left -= 1
        taken = int(smaller - remaining)
        # *taken* counts picks from the smaller group; convert it to a
        # count of "good" picks when the good group is the larger one.
        if ngood > nbad:
            return nsamples - taken
        return taken

    compiled = context.compile_internal(builder, hypergeometric_impl,
                                        sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.laplace")
@glue_lowering("np.random.laplace", types.Float)
@glue_lowering("np.random.laplace", types.Float, types.Float)
def laplace_impl(context, builder, sig, args):
    """
    Lower np.random.laplace() called with zero, one or two scalar
    arguments.  _fill_defaults() is given (0.0, 1.0) as the default
    values before the two-argument routine is compiled.
    """

    def laplace_impl(loc, scale):
        # The formula used depends on which half of [0, 1) the uniform
        # draw falls in.
        draw = np.random.random()
        if draw < 0.5:
            offset = math.log(draw + draw)
            return loc + scale * offset
        offset = math.log(2.0 - draw - draw)
        return loc - scale * offset

    sig, args = _fill_defaults(context, builder, sig, args, (0.0, 1.0))
    compiled = context.compile_internal(builder, laplace_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.logistic")
@glue_lowering("np.random.logistic", types.Float)
@glue_lowering("np.random.logistic", types.Float, types.Float)
def logistic_impl(context, builder, sig, args):
    """
    Lower np.random.logistic() called with zero, one or two scalar
    arguments.  _fill_defaults() is given (0.0, 1.0) as the default
    values before the two-argument routine is compiled.
    """

    def logistic_impl(loc, scale):
        # Log-odds of a single uniform draw, shifted and scaled.
        draw = np.random.random()
        odds = draw / (1.0 - draw)
        return loc + scale * math.log(odds)

    sig, args = _fill_defaults(context, builder, sig, args, (0.0, 1.0))
    compiled = context.compile_internal(builder, logistic_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)

@glue_lowering("np.random.logseries", types.Float)
def logseries_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.logseries(p) call.
    """
    intty = sig.return_type
    _random = np.random.random
    _log = math.log
    _exp = math.exp

    def logseries_impl(p):
        """
        Numpy's algorithm for logseries().

        Raises ValueError if *p* is outside (0, 1].
        """
        # NOTE(review): p == 1.0 passes this check but makes r = log(0);
        # NumPy documents the valid range as (0, 1) -- confirm whether
        # p == 1.0 should be rejected here as well.
        if p <= 0.0 or p > 1.0:
            raise ValueError("logseries(): p outside of (0, 1]")
        r = _log(1.0 - p)

        while 1:
            V = _random()
            if V >= p:
                return 1
            U = _random()
            q = 1.0 - _exp(r * U)
            if V <= q * q:
                if V == 0.0:
                    # log(0) is not finite and cannot be converted to
                    # an integer: reject this draw and try again, as
                    # NumPy's reference implementation does.
                    continue
                return intty(1.0 + _log(V) / _log(q))
            elif V >= q:
                return 1
            else:
                return 2

    res = context.compile_internal(builder, logseries_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@glue_lowering("np.random.negative_binomial", types.int64, types.Float)
def negative_binomial_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.negative_binomial(n, p) call.
    """

    def negative_binomial_impl(n, p):
        """
        Draw a Poisson variate whose rate is itself a gamma draw.

        Raises ValueError if *n* is not positive or *p* is outside [0, 1].
        """
        if n <= 0:
            raise ValueError("negative_binomial(): n <= 0")
        if p < 0.0 or p > 1.0:
            raise ValueError("negative_binomial(): p outside of [0, 1]")
        gamma_scale = (1.0 - p) / p
        rate = np.random.gamma(n, gamma_scale)
        return np.random.poisson(rate)

    compiled = context.compile_internal(builder, negative_binomial_impl,
                                        sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.poisson")
@glue_lowering("np.random.poisson", types.Float)
def poisson_impl(context, builder, sig, args):
    """
    Lower np.random.poisson() called with zero or one scalar argument.

    With one argument, a runtime test on lam >= 10.0 selects between a
    call to the external "numba_poisson_ptrs" function and the routine
    compiled below.  With no argument, only the compiled routine is
    used, with lam fixed to 1.0.  Both paths store their result in the
    same stack slot, which is loaded and returned at the end.
    """
    state_ptr = get_np_state_ptr(context, builder)

    # Stack slot shared by both code paths, plus the two blocks used to
    # continue with the small-lambda path and to join at the end.
    retptr = cgutils.alloca_once(builder, int64_t, name="ret")
    bbcont = builder.append_basic_block("bbcont")
    bbend = builder.append_basic_block("bbend")

    if len(args) == 1:
        lam, = args
        big_lam = builder.fcmp_ordered('>=', lam, ir.Constant(double, 10.0))
        with builder.if_then(big_lam):
            # For lambda >= 10.0, we switch to a more accurate
            # algorithm (see _random.c).
            fnty = ir.FunctionType(int64_t, (rnd_state_ptr_t, double))
            fn = cgutils.get_or_insert_function(builder.function.module, fnty,
                                                "numba_poisson_ptrs")
            ret = builder.call(fn, (state_ptr, lam))
            builder.store(ret, retptr)
            # Jump straight to the join block, skipping the
            # small-lambda code emitted below.
            builder.branch(bbend)

    # Everything emitted from here on belongs to the small-lambda path.
    builder.branch(bbcont)
    builder.position_at_end(bbcont)

    _random = np.random.random
    _exp = math.exp

    def poisson_impl(lam):
        """Numpy's algorithm for poisson() on small *lam*.

        This method is invoked only if the parameter lambda of the
        distribution is small ( < 10 ). The algorithm used is described
        in "Knuth, D. 1969. 'Seminumerical Algorithms. The Art of
        Computer Programming' vol 2.
        """
        if lam < 0.0:
            raise ValueError("poisson(): lambda < 0")
        if lam == 0.0:
            return 0
        enlam = _exp(-lam)
        X = 0
        prod = 1.0
        # Multiply uniform draws together until the product falls to
        # exp(-lam) or below; X counts the draws that did not get there.
        while 1:
            U = _random()
            prod *= U
            if prod <= enlam:
                return X
            X += 1

    if len(args) == 0:
        # No argument given: compile the routine for a float64 lam and
        # pass the constant 1.0.
        sig = signature(sig.return_type, types.float64)
        args = (ir.Constant(double, 1.0),)

    ret = context.compile_internal(builder, poisson_impl, sig, args)
    builder.store(ret, retptr)
    builder.branch(bbend)
    # Join point of both paths: read the result back from the slot.
    builder.position_at_end(bbend)
    res = builder.load(retptr)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@glue_lowering("np.random.power", types.Float)
def power_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.power(a) call.
    """

    def power_impl(a):
        """Raises ValueError if *a* is not positive."""
        if a <= 0.0:
            raise ValueError("power(): a <= 0")
        expo_draw = np.random.standard_exponential()
        base = 1 - math.exp(-expo_draw)
        return math.pow(base, 1./a)

    compiled = context.compile_internal(builder, power_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.rayleigh")
@glue_lowering("np.random.rayleigh", types.Float)
def rayleigh_impl(context, builder, sig, args):
    """
    Lower np.random.rayleigh() called with zero or one scalar argument.
    _fill_defaults() is given (1.0,) as the default value before the
    one-argument routine is compiled.
    """

    def rayleigh_impl(mode):
        """Raises ValueError if *mode* is not positive."""
        if mode <= 0.0:
            raise ValueError("rayleigh(): mode <= 0")
        complement = 1.0 - np.random.random()
        radius = math.sqrt(-2.0 * math.log(complement))
        return mode * radius

    sig, args = _fill_defaults(context, builder, sig, args, (1.0,))
    compiled = context.compile_internal(builder, rayleigh_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.standard_cauchy")
def cauchy_impl(context, builder, sig, args):
    """
    Lower the np.random.standard_cauchy() call.
    """

    def cauchy_impl():
        # Ratio of two standard normal draws; the numerator is drawn
        # first.
        numerator = np.random.standard_normal()
        denominator = np.random.standard_normal()
        return numerator / denominator

    compiled = context.compile_internal(builder, cauchy_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.standard_t", types.Float)
def standard_t_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.standard_t(df) call.
    """

    def standard_t_impl(df):
        # One standard normal draw, then one standard gamma draw of
        # shape df / 2, combined into the final variate.
        normal_draw = np.random.standard_normal()
        gamma_draw = np.random.standard_gamma(df / 2.0)
        return math.sqrt(df / 2.0) * normal_draw / math.sqrt(gamma_draw)

    compiled = context.compile_internal(builder, standard_t_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.wald", types.Float, types.Float)
def wald_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.wald(mean, scale) call.
    """

    def wald_impl(mean, scale):
        """Raises ValueError if *mean* or *scale* is not positive."""
        if mean <= 0.0:
            raise ValueError("wald(): mean <= 0")
        if scale <= 0.0:
            raise ValueError("wald(): scale <= 0")
        ratio = mean / (2.0 * scale)
        z = np.random.standard_normal()
        y = mean * z * z
        root = math.sqrt(4 * scale * y + y * y)
        candidate = mean + ratio * (y - root)
        # A uniform draw decides between the candidate and its
        # counterpart mean**2 / candidate.
        u = np.random.random()
        if u <= mean / (mean + candidate):
            return candidate
        return mean * mean / candidate

    compiled = context.compile_internal(builder, wald_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)


@glue_lowering("np.random.zipf", types.Float)
def zipf_impl(context, builder, sig, args):
    """
    Lower the scalar np.random.zipf(a) call.
    """
    intty = sig.return_type
    _random = np.random.random

    def zipf_impl(a):
        """
        Rejection sampler: candidates are drawn until one is accepted.

        Raises ValueError if *a* is not greater than 1.
        """
        if a <= 1.0:
            raise ValueError("zipf(): a <= 1")
        am1 = a - 1.0
        b = 2.0 ** am1
        while True:
            # 1 - U lies in (0, 1]
            u = 1.0 - _random()
            v = _random()
            candidate = intty(math.floor(u ** (-1.0 / am1)))
            t = (1.0 + 1.0 / candidate) ** am1
            accepted = (candidate >= 1 and
                        v * candidate * (t - 1.0) / (b - 1.0) <= (t / b))
            if accepted:
                return candidate

    compiled = context.compile_internal(builder, zipf_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, compiled)

def do_shuffle_impl(arr, rng):
    """
    Build the in-place shuffle implementation for the buffer type *arr*.

    *rng* selects the source of random indices: "np" uses
    np.random.randint, "py" uses random.randrange.  Raises TypeError if
    *arr* is not a buffer type.
    """
    if not isinstance(arr, types.Buffer):
        raise TypeError("The argument to shuffle() should be a buffer type")

    if rng == "np":
        rand = np.random.randint
    elif rng == "py":
        rand = random.randrange

    if arr.ndim == 1:
        def impl(arr):
            # Walk from the last index down to 1, swapping each element
            # with one at a random index no greater than its own.
            for i in range(arr.shape[0] - 1, 0, -1):
                j = rand(i + 1)
                arr[i], arr[j] = arr[j], arr[i]
    else:
        def impl(arr):
            # Same walk along the first axis; the entries are copied
            # before being written back.
            for i in range(arr.shape[0] - 1, 0, -1):
                j = rand(i + 1)
                arr[i], arr[j] = np.copy(arr[j]), np.copy(arr[i])

    return impl

@overload(random.shuffle)
def shuffle_impl(arr):
    """Overload of random.shuffle(); delegates with the "py" generator."""
    impl = do_shuffle_impl(arr, rng="py")
    return impl

@overload(np.random.shuffle)
def shuffle_impl(arr):
    """Overload of np.random.shuffle(); delegates with the "np" generator."""
    impl = do_shuffle_impl(arr, rng="np")
    return impl

@overload(np.random.permutation)
def permutation_impl(x):
    """
    Overload of np.random.permutation() for an integer or an array.

    Returns None for any other argument type.
    """
    impl = None
    if isinstance(x, types.Integer):
        def impl(x):
            # Shuffle a fresh arange(x)
            shuffled = np.arange(x)
            np.random.shuffle(shuffled)
            return shuffled
    elif isinstance(x, types.Array):
        def impl(x):
            # Shuffle a copy of the array
            shuffled = x.copy()
            np.random.shuffle(shuffled)
            return shuffled
    return impl


# ------------------------------------------------------------------------
# Array-producing variants of scalar random functions

# Register one array-producing lowering per scalar function listed below.
# *arity* is the number of arguments of the array variant; the last
# argument is handled as the output shape in random_arr().
for typing_key, arity in [
    ("np.random.beta", 3),
    ("np.random.binomial", 3),
    ("np.random.chisquare", 2),
    ("np.random.exponential", 2),
    ("np.random.f", 3),
    ("np.random.gamma", 3),
    ("np.random.geometric", 2),
    ("np.random.gumbel", 3),
    ("np.random.hypergeometric", 4),
    ("np.random.laplace", 3),
    ("np.random.logistic", 3),
    ("np.random.lognormal", 3),
    ("np.random.logseries", 2),
    ("np.random.negative_binomial", 3),
    ("np.random.normal", 3),
    ("np.random.pareto", 2),
    ("np.random.poisson", 2),
    ("np.random.power", 2),
    ("np.random.random", 1),
    ("np.random.random_sample", 1),
    ("np.random.ranf", 1),
    ("np.random.sample", 1),
    ("np.random.randint", 3),
    ("np.random.rayleigh", 2),
    ("np.random.standard_cauchy", 1),
    ("np.random.standard_exponential", 1),
    ("np.random.standard_gamma", 2),
    ("np.random.standard_normal", 1),
    ("np.random.standard_t", 2),
    ("np.random.triangular", 4),
    ("np.random.uniform", 3),
    ("np.random.vonmises", 3),
    ("np.random.wald", 3),
    ("np.random.weibull", 2),
    ("np.random.zipf", 2),
    ]:

    # The typing_key=typing_key default captures the current loop value
    # when the function is defined, so each registered function keeps
    # its own key.
    @glue_lowering(typing_key, *(types.Any,) * arity)
    def random_arr(context, builder, sig, args, typing_key=typing_key):
        """
        Lower the array-producing variant of the function *typing_key*:
        allocate the output array, then fill it one element at a time by
        calling the scalar implementation.
        """

        arrty = sig.return_type
        dtype = arrty.dtype
        # Everything but the last argument is passed on to the scalar
        # implementation, which returns one element of type *dtype*.
        scalar_sig = signature(dtype, *sig.args[:-1])
        scalar_args = args[:-1]

        # Allocate array...
        shapes = arrayobj._parse_shape(context, builder, sig.args[-1], args[-1])
        arr = arrayobj._empty_nd_impl(context, builder, arrty, shapes)

        # ... and populate it in natural order
        *mod, fname = typing_key.split('.')
        # Module must be numpy.random
        assert mod == ['np', 'random']
        np_func = getattr(np.random, fname)
        # Resolve the scalar function's type and signature, then fetch
        # its implementation for that signature.
        fnty = context.typing_context.resolve_value_type(np_func)
        resolved_sig = fnty.get_call_type(context.typing_context,
                                          scalar_sig.args, {})
        scalar_impl = context.get_function(fnty, resolved_sig)
        with cgutils.for_range(builder, arr.nitems) as loop:
            # One scalar draw per element, stored at the loop index.
            val = scalar_impl(builder, scalar_args)
            ptr = cgutils.gep(builder, arr.data, loop.index)
            arrayobj.store_item(context, builder, arrty, val, ptr)

        return impl_ret_new_ref(context, builder, sig.return_type, arr._getvalue())


# ------------------------------------------------------------------------
# Irregular aliases: np.random.rand, np.random.randn

@overload(np.random.rand)
def rand(*size):
    """
    Overload of np.random.rand(): forwards to np.random.random(), with
    the dimensions as a shape when any are given.
    """
    if len(size) > 0:
        # Array output
        def rand_impl(*size):
            return np.random.random(size)
    else:
        # Scalar output
        def rand_impl(*size):
            return np.random.random()

    return rand_impl

@overload(np.random.randn)
def randn(*size):
    """
    Overload of np.random.randn(): forwards to
    np.random.standard_normal(), with the dimensions as a shape when
    any are given.
    """
    if len(size) > 0:
        # Array output
        def randn_impl(*size):
            return np.random.standard_normal(size)
    else:
        # Scalar output
        def randn_impl(*size):
            return np.random.standard_normal()

    return randn_impl


# ------------------------------------------------------------------------
# np.random.choice

@overload(np.random.choice)
def choice(a, size=None, replace=True):
    """
    Overload of np.random.choice(a, size=None, replace=True).

    *a* is either an array (asserted to be 1-d) or an integer, which
    stands for the population arange(a).  Without *size* a single
    sample is returned, otherwise an array of samples.  Raises
    TypeError if *a* is neither an array nor an integer.
    """

    # The two helpers below hide the difference between an array
    # population and an implied arange() population.
    if isinstance(a, types.Array):
        # choice() over an array population
        assert a.ndim == 1
        dtype = a.dtype

        @register_jitable
        def get_source_size(a):
            return len(a)

        @register_jitable
        def getitem(a, a_i):
            return a[a_i]

    elif isinstance(a, types.Integer):
        # choice() over an implied arange() population
        dtype = np.intp

        @register_jitable
        def get_source_size(a):
            return a

        @register_jitable
        def getitem(a, a_i):
            return a_i

    else:
        raise TypeError("np.random.choice() first argument should be "
                        "int or array, got %s" % (a,))

    if size in (None, types.none):
        def choice_impl(a, size=None, replace=True):
            """
            choice() implementation returning a single sample
            (note *replace* is ignored)
            """
            n = get_source_size(a)
            i = np.random.randint(0, n)
            return getitem(a, i)

    else:
        def choice_impl(a, size=None, replace=True):
            """
            choice() implementation returning an array of samples
            """
            n = get_source_size(a)
            if replace:
                out = np.empty(size, dtype)
                fl = out.flat
                for i in range(len(fl)):
                    j = np.random.randint(0, n)
                    fl[i] = getitem(a, j)
                return out
            else:
                # Note we have to construct the array to compute out.size
                # (`size` can be an arbitrary int or tuple of ints)
                out = np.empty(size, dtype)
                if out.size > n:
                    raise ValueError("Cannot take a larger sample than "
                                     "population when 'replace=False'")
                # Get a permuted copy of the source array
                # we need this implementation in order to get the
                # np.random.choice inside numba to match the output
                # of np.random.choice outside numba when np.random.seed
                # is set to the same value
                permuted_a = np.random.permutation(a)
                fl = out.flat
                for i in range(len(fl)):
                    fl[i] = permuted_a[i]
                return out

    return choice_impl


# ------------------------------------------------------------------------
# np.random.multinomial

@overload(np.random.multinomial)
def multinomial(n, pvals, size=None):
    """
    Overload of np.random.multinomial(n, pvals, size=None).

    Raises TypeError if *n* is not an integer, if *pvals* is not an
    array or sequence, or if *size* is not an int, a tuple or None.
    """

    dtype = np.intp

    @register_jitable
    def multinomial_inner(n, pvals, out):
        # Numpy's algorithm for multinomial()
        flat = out.flat
        total = out.size
        n_outcomes = len(pvals)

        for start in range(0, total, n_outcomes):
            # Each pass runs one set of n experiments and writes its
            # outcome counts to flat[start:start + n_outcomes].

            # Probability mass not yet assigned to an outcome
            mass_left = 1.0
            # Experiments not yet assigned to an outcome
            trials_left = n
            # Every outcome but the last gets a binomial draw over the
            # experiments still unassigned, using the conditional
            # probability P(X=j | X>=j).
            for j in range(0, n_outcomes - 1):
                p_j = pvals[j]
                hits = np.random.binomial(trials_left, p_j / mass_left)
                flat[start + j] = hits
                trials_left -= hits
                if trials_left <= 0:
                    # Note the output was initialized to zero
                    break
                mass_left -= p_j
            if trials_left > 0:
                # Whatever is left over goes to the last outcome
                flat[start + n_outcomes - 1] = trials_left

    if not isinstance(n, types.Integer):
        raise TypeError("np.random.multinomial(): n should be an "
                        "integer, got %s" % (n,))

    if not isinstance(pvals, (types.Sequence, types.Array)):
        raise TypeError("np.random.multinomial(): pvals should be an "
                        "array or sequence, got %s" % (pvals,))

    if size in (None, types.none):
        def multinomial_impl(n, pvals, size=None):
            """
            multinomial(..., size=None)
            """
            counts = np.zeros(len(pvals), dtype)
            multinomial_inner(n, pvals, counts)
            return counts

        return multinomial_impl

    if isinstance(size, types.Integer):
        def multinomial_impl(n, pvals, size=None):
            """
            multinomial(..., size=int)
            """
            counts = np.zeros((size, len(pvals)), dtype)
            multinomial_inner(n, pvals, counts)
            return counts

        return multinomial_impl

    if isinstance(size, types.BaseTuple):
        def multinomial_impl(n, pvals, size=None):
            """
            multinomial(..., size=tuple)
            """
            counts = np.zeros(size + (len(pvals),), dtype)
            multinomial_inner(n, pvals, counts)
            return counts

        return multinomial_impl

    raise TypeError("np.random.multinomial(): size should be int or "
                    "tuple or None, got %s" % (size,))

# ------------------------------------------------------------------------
# np.random.dirichlet

@overload(np.random.dirichlet)
def dirichlet(alpha, size=None):
    """
    Overload of np.random.dirichlet(alpha, size=None).

    Raises NumbaTypeError if *alpha* is not an array or sequence, or if
    *size* is not an int, a tuple of ints or None.
    """

    @register_jitable
    def dirichlet_arr(alpha, out):

        # Gamma distribution method to generate a Dirichlet distribution

        for a_val in iter(alpha):
            if a_val <= 0:
                raise ValueError("dirichlet: alpha must be > 0.0")

        group_len = len(alpha)
        total = out.size
        flat = out.flat
        for start in range(0, total, group_len):
            # One gamma draw per alpha entry, accumulated into the
            # group total ...
            norm = 0
            for k, w in enumerate(alpha):
                flat[start + k] = np.random.gamma(w, 1)
                norm += flat[start + k].item()
            # ... then divide so that every group sums to 1
            for k in range(group_len):
                flat[start + k] /= norm

    if not isinstance(alpha, (types.Sequence, types.Array)):
        raise NumbaTypeError(
            "np.random.dirichlet(): alpha should be an "
            "array or sequence, got %s" % (alpha,)
        )

    if size in (None, types.none):

        def dirichlet_impl(alpha, size=None):
            """
            dirichlet(..., size=None)
            """
            result = np.empty(len(alpha))
            dirichlet_arr(alpha, result)
            return result

        return dirichlet_impl

    if isinstance(size, types.Integer):

        def dirichlet_impl(alpha, size=None):
            """
            dirichlet(..., size=int)
            """
            result = np.empty((size, len(alpha)))
            dirichlet_arr(alpha, result)
            return result

        return dirichlet_impl

    if isinstance(size, types.UniTuple) and isinstance(size.dtype,
                                                       types.Integer):

        def dirichlet_impl(alpha, size=None):
            """
            dirichlet(..., size=tuple)
            """
            result = np.empty(size + (len(alpha),))
            dirichlet_arr(alpha, result)
            return result

        return dirichlet_impl

    raise NumbaTypeError(
        "np.random.dirichlet(): size should be int or "
        "tuple of ints or None, got %s" % size
    )

# ------------------------------------------------------------------------
# np.random.noncentral_chisquare

@overload(np.random.noncentral_chisquare)
def noncentral_chisquare(df, nonc, size=None):
    """
    Overload of np.random.noncentral_chisquare(df, nonc, size=None).

    Raises NumbaTypeError if *size* is not an int, a tuple of ints or
    None.  The compiled implementations raise ValueError if df <= 0 or
    nonc < 0.
    """

    @register_jitable
    def validate_input(df, nonc):
        if df <= 0:
            raise ValueError("df <= 0")
        if nonc < 0:
            raise ValueError("nonc < 0")

    @register_jitable
    def noncentral_chisquare_single(df, nonc):
        # identical to numpy implementation from distributions.c
        # https://github.com/numpy/numpy/blob/c65bc212ec1987caefba0ea7efe6a55803318de9/numpy/random/src/distributions/distributions.c#L797

        if np.isnan(nonc):
            return np.nan

        if df > 1:
            central = np.random.chisquare(df - 1)
            shifted = np.random.standard_normal() + np.sqrt(nonc)
            return central + shifted * shifted

        count = np.random.poisson(nonc / 2.0)
        return np.random.chisquare(df + 2 * count)

    if size in (None, types.none):
        def noncentral_chisquare_impl(df, nonc, size=None):
            # Scalar output
            validate_input(df, nonc)
            return noncentral_chisquare_single(df, nonc)

        return noncentral_chisquare_impl

    if isinstance(size, types.Integer) or (
            isinstance(size, types.UniTuple)
            and isinstance(size.dtype, types.Integer)):

        def noncentral_chisquare_impl(df, nonc, size=None):
            # Array output: validate once, then fill element by element
            validate_input(df, nonc)
            result = np.empty(size)
            flat = result.flat
            for idx in range(result.size):
                flat[idx] = noncentral_chisquare_single(df, nonc)
            return result

        return noncentral_chisquare_impl

    raise NumbaTypeError(
        "np.random.noncentral_chisquare(): size should be int or "
        "tuple of ints or None, got %s" % size
    )

