use crate::{
    alphabet::Alphabet,
    errors::{Error, InvalidEncodingError, InvalidLengthError},
};
use core::str;
#[cfg(feature = "alloc")]
use alloc::{string::String, vec::Vec};
#[cfg(doc)]
use crate::{Base64, Base64Bcrypt, Base64Crypt, Base64Unpadded, Base64Url, Base64UrlUnpadded};
const PAD: u8 = b'=';
pub trait Encoding: Alphabet {
    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error>;
    fn decode_in_place(buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError>;
    #[cfg(feature = "alloc")]
    fn decode_vec(input: &str) -> Result<Vec<u8>, Error>;
    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError>;
    #[cfg(feature = "alloc")]
    fn encode_string(input: &[u8]) -> String;
    fn encoded_len(bytes: &[u8]) -> usize;
}
impl<T: Alphabet> Encoding for T {
    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error> {
        let (src_unpadded, mut err) = if T::PADDED {
            let (unpadded_len, e) = decode_padding(src.as_ref())?;
            (&src.as_ref()[..unpadded_len], e)
        } else {
            (src.as_ref(), 0)
        };
        let dlen = decoded_len(src_unpadded.len());
        if dlen > dst.len() {
            return Err(Error::InvalidLength);
        }
        let dst = &mut dst[..dlen];
        let mut src_chunks = src_unpadded.chunks_exact(4);
        let mut dst_chunks = dst.chunks_exact_mut(3);
        for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
            err |= Self::decode_3bytes(s, d);
        }
        let src_rem = src_chunks.remainder();
        let dst_rem = dst_chunks.into_remainder();
        err |= !(src_rem.is_empty() || src_rem.len() >= 2) as i16;
        let mut tmp_out = [0u8; 3];
        let mut tmp_in = [b'A'; 4];
        tmp_in[..src_rem.len()].copy_from_slice(src_rem);
        err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
        dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
        if err == 0 {
            validate_last_block::<T>(src.as_ref(), dst)?;
            Ok(dst)
        } else {
            Err(Error::InvalidEncoding)
        }
    }
    #[allow(clippy::integer_arithmetic)]
    fn decode_in_place(mut buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError> {
        let mut err = if T::PADDED {
            let (unpadded_len, e) = decode_padding(buf)?;
            buf = &mut buf[..unpadded_len];
            e
        } else {
            0
        };
        let dlen = decoded_len(buf.len());
        let full_chunks = buf.len() / 4;
        for chunk in 0..full_chunks {
            #[allow(unsafe_code)]
            unsafe {
                debug_assert!(3 * chunk + 3 <= buf.len());
                debug_assert!(4 * chunk + 4 <= buf.len());
                let p3 = buf.as_mut_ptr().add(3 * chunk) as *mut [u8; 3];
                let p4 = buf.as_ptr().add(4 * chunk) as *const [u8; 4];
                let mut tmp_out = [0u8; 3];
                err |= Self::decode_3bytes(&*p4, &mut tmp_out);
                *p3 = tmp_out;
            }
        }
        let src_rem_pos = 4 * full_chunks;
        let src_rem_len = buf.len() - src_rem_pos;
        let dst_rem_pos = 3 * full_chunks;
        let dst_rem_len = dlen - dst_rem_pos;
        err |= !(src_rem_len == 0 || src_rem_len >= 2) as i16;
        let mut tmp_in = [b'A'; 4];
        tmp_in[..src_rem_len].copy_from_slice(&buf[src_rem_pos..]);
        let mut tmp_out = [0u8; 3];
        err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
        if err == 0 {
            #[allow(unsafe_code)]
            unsafe {
                debug_assert!(dst_rem_pos + dst_rem_len <= buf.len());
                debug_assert!(dst_rem_len <= tmp_out.len());
                debug_assert!(dlen <= buf.len());
                core::ptr::copy_nonoverlapping(
                    tmp_out.as_ptr(),
                    buf.as_mut_ptr().add(dst_rem_pos),
                    dst_rem_len,
                );
                Ok(buf.get_unchecked(..dlen))
            }
        } else {
            Err(InvalidEncodingError)
        }
    }
    #[cfg(feature = "alloc")]
    fn decode_vec(input: &str) -> Result<Vec<u8>, Error> {
        let mut output = vec![0u8; decoded_len(input.len())];
        let len = Self::decode(input, &mut output)?.len();
        if len <= output.len() {
            output.truncate(len);
            Ok(output)
        } else {
            Err(Error::InvalidLength)
        }
    }
    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError> {
        let elen = match encoded_len_inner(src.len(), T::PADDED) {
            Some(v) => v,
            None => return Err(InvalidLengthError),
        };
        if elen > dst.len() {
            return Err(InvalidLengthError);
        }
        let dst = &mut dst[..elen];
        let mut src_chunks = src.chunks_exact(3);
        let mut dst_chunks = dst.chunks_exact_mut(4);
        for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
            Self::encode_3bytes(s, d);
        }
        let src_rem = src_chunks.remainder();
        if T::PADDED {
            if let Some(dst_rem) = dst_chunks.next() {
                let mut tmp = [0u8; 3];
                tmp[..src_rem.len()].copy_from_slice(src_rem);
                Self::encode_3bytes(&tmp, dst_rem);
                let flag = src_rem.len() == 1;
                let mask = (flag as u8).wrapping_sub(1);
                dst_rem[2] = (dst_rem[2] & mask) | (PAD & !mask);
                dst_rem[3] = PAD;
            }
        } else {
            let dst_rem = dst_chunks.into_remainder();
            let mut tmp_in = [0u8; 3];
            let mut tmp_out = [0u8; 4];
            tmp_in[..src_rem.len()].copy_from_slice(src_rem);
            Self::encode_3bytes(&tmp_in, &mut tmp_out);
            dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
        }
        debug_assert!(str::from_utf8(dst).is_ok());
        #[allow(unsafe_code)]
        Ok(unsafe { str::from_utf8_unchecked(dst) })
    }
    #[cfg(feature = "alloc")]
    fn encode_string(input: &[u8]) -> String {
        let elen = encoded_len_inner(input.len(), T::PADDED).expect("input is too big");
        let mut dst = vec![0u8; elen];
        let res = Self::encode(input, &mut dst).expect("encoding error");
        debug_assert_eq!(elen, res.len());
        debug_assert!(str::from_utf8(&dst).is_ok());
        #[allow(unsafe_code)]
        unsafe {
            String::from_utf8_unchecked(dst)
        }
    }
    fn encoded_len(bytes: &[u8]) -> usize {
        encoded_len_inner(bytes.len(), T::PADDED).unwrap_or(0)
    }
}
#[inline(always)]
pub(crate) fn decode_padding(input: &[u8]) -> Result<(usize, i16), InvalidEncodingError> {
    if input.len() % 4 != 0 {
        return Err(InvalidEncodingError);
    }
    let unpadded_len = match *input {
        [.., b0, b1] => is_pad_ct(b0)
            .checked_add(is_pad_ct(b1))
            .and_then(|len| len.try_into().ok())
            .and_then(|len| input.len().checked_sub(len))
            .ok_or(InvalidEncodingError)?,
        _ => input.len(),
    };
    let padding_len = input
        .len()
        .checked_sub(unpadded_len)
        .ok_or(InvalidEncodingError)?;
    let err = match *input {
        [.., b0] if padding_len == 1 => is_pad_ct(b0) ^ 1,
        [.., b0, b1] if padding_len == 2 => (is_pad_ct(b0) & is_pad_ct(b1)) ^ 1,
        _ => {
            if padding_len == 0 {
                0
            } else {
                return Err(InvalidEncodingError);
            }
        }
    };
    Ok((unpadded_len, err))
}
fn validate_last_block<T: Alphabet>(encoded: &[u8], decoded: &[u8]) -> Result<(), Error> {
    if encoded.is_empty() && decoded.is_empty() {
        return Ok(());
    }
    #[allow(clippy::integer_arithmetic)]
    fn last_block_start(bytes: &[u8], block_size: usize) -> usize {
        (bytes.len().saturating_sub(1) / block_size) * block_size
    }
    let enc_block = encoded
        .get(last_block_start(encoded, 4)..)
        .ok_or(Error::InvalidEncoding)?;
    let dec_block = decoded
        .get(last_block_start(decoded, 3)..)
        .ok_or(Error::InvalidEncoding)?;
    let mut buf = [0u8; 4];
    let block = T::encode(dec_block, &mut buf)?;
    if block
        .as_bytes()
        .iter()
        .zip(enc_block.iter())
        .fold(0, |acc, (a, b)| acc | (a ^ b))
        == 0
    {
        Ok(())
    } else {
        Err(Error::InvalidEncoding)
    }
}
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
pub(crate) fn decoded_len(input_len: usize) -> usize {
    let k = input_len / 4;
    let l = input_len - 4 * k;
    3 * k + (3 * l) / 4
}
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
fn is_pad_ct(input: u8) -> i16 {
    ((((PAD as i16 - 1) - input as i16) & (input as i16 - (PAD as i16 + 1))) >> 8) & 1
}
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
const fn encoded_len_inner(n: usize, padded: bool) -> Option<usize> {
    match n.checked_mul(4) {
        Some(q) => {
            if padded {
                Some(((q / 3) + 3) & !3)
            } else {
                Some((q / 3) + (q % 3 != 0) as usize)
            }
        }
        None => None,
    }
}