"""The new soundfile backend, which will become the default from 0.8.0 onward."""
import warnings
from typing import Optional, Tuple

import torch
from torchaudio._internal import module_utils as _mod_utils

from .common import AudioMetaData


if _mod_utils.is_soundfile_available():
    import soundfile

# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristic, and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired by
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
# It is read by ``_get_bit_depth``, which warns and falls back to 0 for subtypes missing here.
_SUBTYPE_TO_BITS_PER_SAMPLE = {
    "PCM_S8": 8,  # Signed 8 bit data
    "PCM_16": 16,  # Signed 16 bit data
    "PCM_24": 24,  # Signed 24 bit data
    "PCM_32": 32,  # Signed 32 bit data
    "PCM_U8": 8,  # Unsigned 8 bit data (WAV and RAW only)
    "FLOAT": 32,  # 32 bit float data
    "DOUBLE": 64,  # 64 bit float data
    "ULAW": 8,  # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
    "ALAW": 8,  # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
    "IMA_ADPCM": 0,  # IMA ADPCM.
    "MS_ADPCM": 0,  # Microsoft ADPCM.
    "GSM610": 0,  # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
    "VOX_ADPCM": 0,  # OKI / Dialogix ADPCM
    "G721_32": 0,  # 32kbs G721 ADPCM encoding.
    "G723_24": 0,  # 24kbs G723 ADPCM encoding.
    "G723_40": 0,  # 40kbs G723 ADPCM encoding.
    "DWVW_12": 12,  # 12 bit Delta Width Variable Word encoding.
    "DWVW_16": 16,  # 16 bit Delta Width Variable Word encoding.
    "DWVW_24": 24,  # 24 bit Delta Width Variable Word encoding.
    "DWVW_N": 0,  # N bit Delta Width Variable Word encoding.
    "DPCM_8": 8,  # 8 bit differential PCM (XI only)
    "DPCM_16": 16,  # 16 bit differential PCM (XI only)
    "VORBIS": 0,  # Xiph Vorbis encoding. (lossy)
    "ALAC_16": 16,  # Apple Lossless Audio Codec (16 bit).
    "ALAC_20": 20,  # Apple Lossless Audio Codec (20 bit).
    "ALAC_24": 24,  # Apple Lossless Audio Codec (24 bit).
    "ALAC_32": 32,  # Apple Lossless Audio Codec (32 bit).
}


def _get_bit_depth(subtype):
    """Return the bits per sample for a soundfile ``subtype``.

    Subtypes absent from ``_SUBTYPE_TO_BITS_PER_SAMPLE`` trigger a warning
    asking the user to report them, and are reported as 0 bits.
    """
    bits = _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype)
    if bits is not None:
        return bits
    # Unknown subtype: warn, then fall back to the "can't be inferred" value.
    warnings.warn(
        f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
        "attribute will be set to 0. If you are seeing this warning, please "
        "report by opening an issue on github (after checking for existing/closed ones). "
        "You may otherwise ignore this warning."
    )
    return 0


# Mapping from soundfile subtype to the encoding name reported in ``AudioMetaData``.
# ``_get_encoding`` reports subtypes not listed here as "UNKNOWN". It also special-cases
# FLAC files by their container format rather than looking up the subtype.
_SUBTYPE_TO_ENCODING = {
    "PCM_S8": "PCM_S",
    "PCM_16": "PCM_S",
    "PCM_24": "PCM_S",
    "PCM_32": "PCM_S",
    "PCM_U8": "PCM_U",
    "FLOAT": "PCM_F",
    "DOUBLE": "PCM_F",
    "ULAW": "ULAW",
    "ALAW": "ALAW",
    "VORBIS": "VORBIS",
}


def _get_encoding(format: str, subtype: str):
    if format == "FLAC":
        return "FLAC"
    return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")


@_mod_utils.requires_soundfile()
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
    """Get signal information of an audio file.

    Note:
        ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
        ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend,
        which has a restriction on type annotation due to TorchScript compiler compatibility.

    Args:
        filepath (path-like object or file-like object):
            Source of audio data.
        format (str or None, optional):
            Not used. PySoundFile does not accept format hint.

    Returns:
        AudioMetaData: meta data of the given audio. ``bits_per_sample`` is 0 when the
        bit depth cannot be determined from the soundfile subtype.
    """
    meta = soundfile.info(filepath)
    # The bit depth is computed first, matching the original argument evaluation order,
    # so any "unknown subtype" warning is emitted before the encoding lookup.
    bit_depth = _get_bit_depth(meta.subtype)
    encoding_name = _get_encoding(meta.format, meta.subtype)
    return AudioMetaData(
        meta.samplerate,
        meta.frames,
        meta.channels,
        bits_per_sample=bit_depth,
        encoding=encoding_name,
    )


_SUBTYPE2DTYPE = {
    "PCM_S8": "int8",
    "PCM_U8": "uint8",
    "PCM_16": "int16",
    "PCM_32": "int32",
    "FLOAT": "float32",
    "DOUBLE": "float64",
}


@_mod_utils.requires_soundfile()
def load(
    filepath: str,
    frame_offset: int = 0,
    num_frames: int = -1,
    normalize: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from file.

    Note:
        The formats this function can handle depend on the soundfile installation.
        This function is tested on the following formats;

        * WAV

            * 32-bit floating-point
            * 32-bit signed integer
            * 16-bit signed integer
            * 8-bit unsigned integer

        * FLAC
        * OGG/VORBIS
        * SPHERE

    By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
    ``float32`` dtype, and the shape of `[channel, time]`.

    .. warning::

       ``normalize`` argument does not perform volume normalization.
       It only converts the sample type to `torch.float32` from the native sample
       type.

       When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
       signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``,
       this function can return integer Tensor, where the samples are expressed within the whole range
       of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM,
       ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not
       support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors.
       A 64-bit floating-point WAV loaded with ``normalize=False`` yields a ``float64`` tensor.

       ``normalize`` argument has no effect on 32-bit floating-point WAV and other formats, such as
       ``flac`` and ``mp3``.

       For these formats, this function always returns a ``float32`` Tensor.

    Note:
        ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
        ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend,
        which has a restriction on type annotation due to TorchScript compiler compatibility.

    Args:
        filepath (path-like object or file-like object):
            Source of audio data.
        frame_offset (int, optional):
            Number of frames to skip before start reading data.
        num_frames (int, optional):
            Maximum number of frames to read. ``-1`` reads all the remaining samples,
            starting from ``frame_offset``.
            This function may return fewer frames if there are not enough
            frames in the given file.
        normalize (bool, optional):
            When ``True``, this function converts the native sample type to ``float32``.
            Default: ``True``.

            If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
            integer type.
            This argument has no effect for formats other than integer WAV type.

        channels_first (bool, optional):
            When True, the returned Tensor has dimension `[channel, time]`.
            Otherwise, the returned Tensor's dimension is `[time, channel]`.
        format (str or None, optional):
            Not used. PySoundFile does not accept format hint.

    Returns:
        (torch.Tensor, int): Resulting Tensor and sample rate.
            If the input file has integer wav format and normalization is off, then it has
            integer type, else floating-point type (``float32`` except for 64-bit float WAV
            with normalization off). If ``channels_first=True``, it has
            `[channel, time]` else `[time, channel]`.

    Raises:
        ValueError: If ``normalize=False`` and the WAV subtype has no supported sample dtype.
    """
    with soundfile.SoundFile(filepath, "r") as file_:
        if file_.format != "WAV" or normalize:
            dtype = "float32"
        elif file_.subtype not in _SUBTYPE2DTYPE:
            raise ValueError(f"Unsupported subtype: {file_.subtype}")
        else:
            dtype = _SUBTYPE2DTYPE[file_.subtype]

        # soundfile can only read into float64/float32/int32/int16 buffers. So 8-bit
        # data is read as int16, where libsndfile stores the 8-bit sample in the upper
        # byte, and narrowed back to 8 bits below.
        read_dtype = "int16" if dtype in ("int8", "uint8") else dtype
        frames = file_._prepare_read(frame_offset, None, num_frames)
        waveform = file_.read(frames, read_dtype, always_2d=True)
        sample_rate = file_.samplerate

    waveform = torch.from_numpy(waveform)
    if dtype == "int8":
        # Signed 8-bit s is delivered as s << 8.
        waveform = (waveform >> 8).to(torch.int8)
    elif dtype == "uint8":
        # Unsigned 8-bit u is delivered as (u - 128) << 8; undo the shift and offset.
        waveform = ((waveform >> 8) + 128).to(torch.uint8)
    if channels_first:
        waveform = waveform.t()
    return waveform, sample_rate


def _get_subtype_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int):
    if not encoding:
        if not bits_per_sample:
            subtype = {
                torch.uint8: "PCM_U8",
                torch.int16: "PCM_16",
                torch.int32: "PCM_32",
                torch.float32: "FLOAT",
                torch.float64: "DOUBLE",
            }.get(dtype)
            if not subtype:
                raise ValueError(f"Unsupported dtype for wav: {dtype}")
            return subtype
        if bits_per_sample == 8:
            return "PCM_U8"
        return f"PCM_{bits_per_sample}"
    if encoding == "PCM_S":
        if not bits_per_sample:
            return "PCM_32"
        if bits_per_sample == 8:
            raise ValueError("wav does not support 8-bit signed PCM encoding.")
        return f"PCM_{bits_per_sample}"
    if encoding == "PCM_U":
        if bits_per_sample in (None, 8):
            return "PCM_U8"
        raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
    if encoding == "PCM_F":
        if bits_per_sample in (None, 32):
            return "FLOAT"
        if bits_per_sample == 64:
            return "DOUBLE"
        raise ValueError("wav only supports 32/64-bit float PCM encoding.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "ULAW"
        raise ValueError("wav only supports 8-bit mu-law encoding.")
    if encoding == "ALAW":
        if bits_per_sample in (None, 8):
            return "ALAW"
        raise ValueError("wav only supports 8-bit a-law encoding.")
    raise ValueError(f"wav does not support {encoding}.")


def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
    if encoding in (None, "PCM_S"):
        return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
    if encoding in ("PCM_U", "PCM_F"):
        raise ValueError(f"sph does not support {encoding} encoding.")
    if encoding == "ULAW":
        if bits_per_sample in (None, 8):
            return "ULAW"
        raise ValueError("sph only supports 8-bit for mu-law encoding.")
    if encoding == "ALAW":
        return "ALAW"
    raise ValueError(f"sph does not support {encoding}.")


def _get_subtype(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int):
    if format == "wav":
        return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
    if format == "flac":
        if encoding:
            raise ValueError("flac does not support encoding.")
        if not bits_per_sample:
            return "PCM_16"
        if bits_per_sample > 24:
            raise ValueError("flac does not support bits_per_sample > 24.")
        return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
    if format in ("ogg", "vorbis"):
        if encoding or bits_per_sample:
            raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
        return "VORBIS"
    if format == "sph":
        return _get_subtype_for_sphere(encoding, bits_per_sample)
    if format in ("nis", "nist"):
        return "PCM_16"
    raise ValueError(f"Unsupported format: {format}")


@_mod_utils.requires_soundfile()
def save(
    filepath: str,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    compression: Optional[float] = None,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
):
    """Save audio data to file.

    Note:
        The formats this function can handle depend on the soundfile installation.
        This function is tested on the following formats;

        * WAV

            * 32-bit floating-point
            * 32-bit signed integer
            * 16-bit signed integer
            * 8-bit unsigned integer

        * FLAC
        * OGG/VORBIS
        * SPHERE

    Note:
        ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
        ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend,
        which has a restriction on type annotation due to TorchScript compiler compatibility.

    Args:
        filepath (str, pathlib.Path or file-like object): Destination of the audio data.
        src (torch.Tensor): Audio data to save. Must be a 2D tensor.
        sample_rate (int): Sampling rate.
        channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
            otherwise `[time, channel]`.
        compression (float or None, optional): Not used.
            It is here only for interface compatibility with the "sox_io" backend.
        format (str or None, optional): Override the audio format.
            When ``filepath`` argument is path-like object, audio format is
            inferred from file extension unless this argument is given, in which
            case this argument takes precedence. If the file extension is missing
            or different, you can specify the correct format with this argument.

            When ``filepath`` argument is file-like object,
            this argument is required.

            Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
            ``"flac"`` and ``"sph"``.
        encoding (str or None, optional): Changes the encoding for supported formats.
            This argument is effective only for ``"wav"`` and ``"sph"``;
            ``"flac"`` and ``"ogg"``/``"vorbis"`` reject it. Valid values are;

                - ``"PCM_S"`` (signed integer Linear PCM)
                - ``"PCM_U"`` (unsigned integer Linear PCM)
                - ``"PCM_F"`` (floating point PCM)
                - ``"ULAW"`` (mu-law)
                - ``"ALAW"`` (a-law)

        bits_per_sample (int or None, optional): Changes the bit depth for the
            supported formats.
            When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
            you can change the bit depth.
            Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.

    Raises:
        ValueError: If ``src`` is not 2D, ``bits_per_sample`` is not one of the valid
            values, or the format/encoding/bit-depth combination is not supported.
        RuntimeError: If ``filepath`` is a file-like object and ``format`` is not given.

    Supported formats/encodings/bit depth/compression are:

    ``"wav"``
        - 32-bit floating-point PCM
        - 32-bit signed integer PCM
        - 24-bit signed integer PCM
        - 16-bit signed integer PCM
        - 8-bit unsigned integer PCM
        - 8-bit mu-law
        - 8-bit a-law

        Note:
            Default encoding/bit depth is determined by the dtype of
            the input Tensor.

    ``"flac"``
        - 8-bit
        - 16-bit (default)
        - 24-bit

    ``"ogg"``, ``"vorbis"``
        - Doesn't accept changing configuration.

    ``"sph"``
        - 8-bit signed integer PCM
        - 16-bit signed integer PCM
        - 24-bit signed integer PCM
        - 32-bit signed integer PCM (default)
        - 8-bit mu-law
        - 8-bit a-law
        - 16-bit a-law
        - 24-bit a-law
        - 32-bit a-law

    """
    if src.ndim != 2:
        raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
    if compression is not None:
        warnings.warn(
            '`save` function of "soundfile" backend does not support "compression" parameter. '
            "The argument is silently ignored."
        )
    if format is not None:
        # An explicit format takes precedence over the file extension, as documented.
        ext = format.lower()
    elif hasattr(filepath, "write"):
        raise RuntimeError("`format` is required when saving to file object.")
    else:
        ext = str(filepath).split(".")[-1].lower()

    if bits_per_sample not in (None, 8, 16, 24, 32, 64):
        raise ValueError("Invalid bits_per_sample.")
    if bits_per_sample == 24:
        warnings.warn(
            "Saving audio with 24 bits per sample might warp samples near -1. "
            "Using 16 bits per sample might be able to avoid this."
        )
    subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)

    # soundfile does not recognize "sph" (the extension used in TED-LIUM), "nis" or "nist"
    # as NIST, nor "vorbis" as a container name. Translate these to soundfile's own format
    # names, whether they came from the file extension or from the ``format`` argument.
    if ext in ("nis", "nist", "sph"):
        format = "NIST"
    elif ext == "vorbis":
        format = "OGG"

    if channels_first:
        src = src.t()

    soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)
