#include <BAN/Assert.h>
#include <BAN/Math.h>
#include <BAN/Traits.h>

#include <ctype.h>
#include <math.h>
#include <scanf_impl.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

enum class LengthModifier
{
	none,
	hh,
	h,
	l,
	ll,
	j,
	z,
	t,
	L,
};

struct Conversion
{
	bool			suppress	= false;
	int				field_width	= -1;
	bool			allocate	= false;
	LengthModifier	length		= LengthModifier::none;
	char			conversion	= '\0';
};

Conversion parse_conversion_specifier(const char*& format)
{
	Conversion conversion;

	if (*format == '*')
	{
		conversion.suppress = true;
		format++;
	}

	if (isdigit(*format))
	{
		conversion.field_width = 0;
		while (isdigit(*format))
		{
			conversion.field_width = (conversion.field_width * 10) + (*format - '0');
			format++;
		}
	}

	if (*format == 'm')
	{
		conversion.allocate = true;
		format++;
	}

	if (*format == 'h')
	{
		conversion.length = LengthModifier::h;
		format++;
		if (*format == 'h')
		{
			conversion.length = LengthModifier::hh;
			format++;
		}
	}
	else if (*format == 'l')
	{
		conversion.length = LengthModifier::l;
		format++;
		if (*format == 'l')
		{
			conversion.length = LengthModifier::ll;
			format++;
		}
	}
	else if (*format == 'j')
	{
		conversion.length = LengthModifier::j;
		format++;
	}
	else if (*format == 'z')
	{
		conversion.length = LengthModifier::z;
		format++;
	}
	else if (*format == 't')
	{
		conversion.length = LengthModifier::t;
		format++;
	}
	else if (*format == 'L')
	{
		conversion.length = LengthModifier::L;
		format++;
	}

	conversion.conversion = *format;
	format++;

	return conversion;
}

template<int BASE>
using BASE_TYPE = BAN::integral_constant<int, BASE>;

template<bool UNSIGNED>
using IS_UNSIGNED = BAN::integral_constant<bool, UNSIGNED>;

int scanf_impl(const char* format, va_list arguments, int (*__getc_fun)(void*), void* data)
{
	static constexpr int DONE = -1;
	static constexpr int NONE = -2;

	int nread = 0;
	int conversions = 0;
	int in = NONE;

	enum class ConversionResult
	{
		NONE,
		SUCCESS,
		INPUT_FAILURE,
		MATCH_FAILURE,
	};

	auto get_input =
		[&](bool advance) -> void
		{
			if (in == DONE)
				return;
			if (advance || in == NONE)
			{
				in = __getc_fun(data);
				nread++;
			}
		};

	auto parse_integer_internal =
		[&get_input, &in]<int BASE, typename T>(BASE_TYPE<BASE>, bool negative, int width, T* out) -> ConversionResult
		{
			constexpr auto is_base_digit =
				[](char c) -> bool
				{
					c = tolower(c);
					if ('0' <= c && c <= '9')
						return c - '0' < BASE;
					if ('a' <= c && c <= 'z')
						return c - 'a' + 10 < BASE;
					return false;
				};
			constexpr auto get_base_digit = [](char c) -> T { if (c <= '9') return c - '0'; return tolower(c) - 'a' + 10; };

			if (!is_base_digit(in))
				return ConversionResult::MATCH_FAILURE;

			*out = 0;
			while (width-- && is_base_digit(in))
			{
				*out = (*out * BASE) + get_base_digit(in);
				get_input(true);
			}

			if (negative)
				*out = -*out;

			return ConversionResult::SUCCESS;
		};

	auto parse_integer_typed =
		[&parse_integer_internal, &arguments, &get_input, &in]<int BASE, typename T>(BASE_TYPE<BASE>, bool suppress, int width, T*) -> ConversionResult
		{
			T dummy;
			T* out = suppress ? &dummy : va_arg(arguments, T*);

			bool negative = (in == '-');
			if (in == '-' || in == '+')
			{
				get_input(true);
				if (--width == 0)
					return ConversionResult::MATCH_FAILURE;
			}

			if constexpr(BASE == 0)
			{
				if (in != '0')
					return parse_integer_internal(BASE_TYPE<10>{}, negative, width, out);
				else
				{
					get_input(true);
					if (--width == 0)
					{
						*out = 0;
						return ConversionResult::SUCCESS;
					}
					if ('0' <= in && in <= '7')
						return parse_integer_internal(BASE_TYPE<8>{}, negative, width, out);
					if (tolower(in) == 'x')
					{
						get_input(true);
						if (--width == 0)
							return ConversionResult::MATCH_FAILURE;
						return parse_integer_internal(BASE_TYPE<16>{}, negative, width, out);
					}
					*out = 0;
					return ConversionResult::SUCCESS;
				}
			}

			if constexpr(BASE == 16)
			{
				if (in == '0')
				{
					get_input(true);
					width--;
					if (tolower(in) == 'x')
					{
						get_input(true);
						width--;
					}
					if (width <= 0)
					{
						*out = 0;
						return ConversionResult::SUCCESS;
					}
				}
			}

			return parse_integer_internal(BASE_TYPE<BASE>{}, negative, width, out);
		};

	auto parse_integer =
		[&parse_integer_typed, &get_input, &in]<int BASE, bool UNSIGNED>(BASE_TYPE<BASE>, IS_UNSIGNED<UNSIGNED>, bool suppress, int width, LengthModifier length) -> ConversionResult
		{
			get_input(false);
			while (isspace(in))
				get_input(true);
			if (in == DONE)
				return ConversionResult::INPUT_FAILURE;

			if (width == -1)
				width = __INT_MAX__;

			switch (length)
			{
				case LengthModifier::none:	return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, unsigned int*,		int*>			{});
				case LengthModifier::hh:	return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, unsigned char*,		signed char*>	{});
				case LengthModifier::h:		return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, unsigned short*,		short*>			{});
				case LengthModifier::l:		return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, unsigned long*,		long*>			{});
				case LengthModifier::ll:	return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, unsigned long long*,	long long*>		{});
				case LengthModifier::j:		return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, uintmax_t*,			intmax_t*>		{});
				case LengthModifier::z:		return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, size_t*,				ssize_t*>		{});
				case LengthModifier::t:		return parse_integer_typed(BASE_TYPE<BASE>{}, suppress, width, BAN::either_or_t<UNSIGNED, size_t*,				ptrdiff_t*>		{});
											static_assert(sizeof(size_t) == sizeof(ptrdiff_t));
				default:
					return ConversionResult::MATCH_FAILURE;
			}
		};

#if __enable_sse
	auto parse_floating_point_internal =
		[&parse_integer_internal, &get_input, &in]<int BASE, typename T>(BASE_TYPE<BASE>, bool negative, int width, T* out, bool require_start = true) -> ConversionResult
		{
			constexpr auto is_base_digit =
				[](char c) -> bool
				{
					c = tolower(c);
					if ('0' <= c && c <= '9')
						return c - '0' < BASE;
					if ('a' <= c && c <= 'z')
						return c - 'a' + 10 < BASE;
					return false;
				};
			constexpr auto get_base_digit = [](char c) -> T { if (c <= '9') return T(c - '0'); return T(tolower(c) - 'a' + 10); };

			if (require_start && !is_base_digit(in))
				return ConversionResult::MATCH_FAILURE;

			*out = T(0);

			// Parse whole part
			while (width > 0 && is_base_digit(in))
			{
				*out = (*out * BASE) + get_base_digit(in);
				get_input(true);
				width--;
			}
			if (width == 0)
				goto done;

			// Parse fractional part
			if (in == '.')
			{
				get_input(true);
				width--;
				T multiplier = T(1) / T(BASE);
				while (width > 0 && is_base_digit(in))
				{
					*out += get_base_digit(in) * multiplier;
					multiplier /= T(BASE);
					get_input(true);
					width--;
				}
			}
			if (width == 0)
				goto done;

			// Parse exponent
			static_assert(BASE == 10 || BASE == 16);
			if ((BASE == 10 && tolower(in) == 'e') || (BASE == 16 && tolower(in) == 'p'))
			{
				get_input(true);
				width--;
				bool exp_negative = (in == '-');
				if (in == '+' || in == '-')
				{
					get_input(true);
					if (--width == 0)
						goto done;
				}
				int exp;
				if (parse_integer_internal(BASE_TYPE<10>{}, exp_negative, width, &exp) == ConversionResult::SUCCESS)
					*out *= BAN::Math::pow<T>(BASE == 10 ? T(10) : T(2), T(exp));
			}

		done:
			if (negative)
				*out = -*out;

			return ConversionResult::SUCCESS;
		};

	auto parse_floating_point_typed =
		[&parse_floating_point_internal, &arguments, &get_input, &in]<typename T>(bool suppress, int width, T*) -> ConversionResult
		{
			T dummy;
			T* out = suppress ? &dummy : va_arg(arguments, T*);

			get_input(false);
			while (isspace(in))
				get_input(true);
			if (in == DONE)
				return ConversionResult::INPUT_FAILURE;

			bool negative = (in == '-');
			if (in == '-' || in == '+')
			{
				get_input(true);
				if (--width == 0)
					return ConversionResult::MATCH_FAILURE;
			}

			if (tolower(in) == 'i')
			{
				get_input(true);
				if (--width == 0)
					return ConversionResult::MATCH_FAILURE;
				if (tolower(in) != 'n')
					return ConversionResult::MATCH_FAILURE;
				get_input(true);
				if (--width == 0)
					return ConversionResult::MATCH_FAILURE;
				if (tolower(in) != 'f')
					return ConversionResult::MATCH_FAILURE;
				if constexpr(sizeof(T) == sizeof(float))
					*out = HUGE_VALF;
				else if constexpr(sizeof(T) == sizeof(double))
					*out = HUGE_VAL;
				else if constexpr(sizeof(T) == sizeof(long double))
					*out = HUGE_VALL;
				else
					[]<bool flag = false>() { static_assert(flag); };

				if (negative)
					*out = -*out;

				return ConversionResult::SUCCESS;
			}

			if (tolower(in) == 'n')
			{
				get_input(true);
				if (--width == 0)
					return ConversionResult::MATCH_FAILURE;
				if (tolower(in) != 'a')
					return ConversionResult::MATCH_FAILURE;
				get_input(true);
				if (--width == 0)
					return ConversionResult::MATCH_FAILURE;
				if (tolower(in) != 'n')
					return ConversionResult::MATCH_FAILURE;
				if constexpr(sizeof(T) == sizeof(float))
					*out = nanf("");
				else if constexpr(sizeof(T) == sizeof(double))
					*out = nan("");
				else if constexpr(sizeof(T) == sizeof(long double))
					*out = nanl("");
				else
					[]<bool flag = false>() { static_assert(flag); };

				return ConversionResult::SUCCESS;
			}

			if (in == '0' && width > 2)
			{
				get_input(true);
				width--;
				if (tolower(in) == 'x')
				{
					get_input(true);
					width--;
					return parse_floating_point_internal(BASE_TYPE<16>{}, negative, width, out);
				}
				return parse_floating_point_internal(BASE_TYPE<10>{}, negative, width, out, false);
			}

			return parse_floating_point_internal(BASE_TYPE<10>{}, negative, width, out);
		};

	auto parse_floating_point =
		[&parse_floating_point_typed](bool suppress, int width, LengthModifier length) -> ConversionResult
		{
			if (width == -1)
				width = __INT_MAX__;

			switch (length)
			{
				case LengthModifier::none:	return parse_floating_point_typed(suppress, width, static_cast<float*>			(nullptr));
				case LengthModifier::l:		return parse_floating_point_typed(suppress, width, static_cast<double*>			(nullptr));
				case LengthModifier::L:		return parse_floating_point_typed(suppress, width, static_cast<long double*>	(nullptr));
				default:
					return ConversionResult::MATCH_FAILURE;
			}
		};
#endif

	auto parse_string =
		[&arguments, &get_input, &in](uint8_t* mask, bool exclude, bool suppress, bool allocate, int min_len, int max_len, bool terminate) -> ConversionResult
		{
			char* temp_dummy;
			char** outp = nullptr;
			if (suppress)
				;
			else if (allocate)
			{
				outp = va_arg(arguments, char**);
				*outp = (char*)malloc(BUFSIZ);
				if (*outp == nullptr)
					return ConversionResult::MATCH_FAILURE;
			}
			else
			{
				temp_dummy = va_arg(arguments, char*);
				outp = &temp_dummy;
			}

			const uint8_t xor_mask = exclude ? 0xFF : 0x00;

			get_input(false);
			if (in == DONE)
			{
				if (allocate)
					free(*outp);
				*outp = nullptr;
				return ConversionResult::INPUT_FAILURE;
			}

			int len = 0;
			while (len < max_len && in != DONE && ((mask[in / 8] ^ xor_mask) & (1 << (in % 8))))
			{
				len++;
				if (!suppress)
				{
					(*outp)[len - 1] = in;
					if (allocate && len % BUFSIZ == 0)
					{
						char* newp = (char*)realloc(*outp, len + BUFSIZ);
						if (newp == nullptr)
						{
							free(*outp);
							*outp = nullptr;
							return ConversionResult::MATCH_FAILURE;
						}
						*outp = newp;
					}
				}
				get_input(true);
			}
			if (len < min_len)
			{
				if (allocate)
					free(*outp);
				*outp = nullptr;
				return ConversionResult::MATCH_FAILURE;
			}

			if (!suppress && terminate)
				(*outp)[len] = '\0';

			return ConversionResult::SUCCESS;
		};

	while (isspace(*format) || isprint(*format))
	{
		if (*format == '%')
		{
			format++;
			auto conversion = parse_conversion_specifier(format);

			ConversionResult result = ConversionResult::NONE;
			switch (conversion.conversion)
			{
				case 'd': result = parse_integer(BASE_TYPE<10>{}, IS_UNSIGNED<false>{}, conversion.suppress, conversion.field_width, conversion.length); break;
				case 'i': result = parse_integer(BASE_TYPE<0> {}, IS_UNSIGNED<false>{}, conversion.suppress, conversion.field_width, conversion.length); break;
				case 'o': result = parse_integer(BASE_TYPE<8> {}, IS_UNSIGNED<true> {}, conversion.suppress, conversion.field_width, conversion.length); break;
				case 'u': result = parse_integer(BASE_TYPE<10>{}, IS_UNSIGNED<true> {}, conversion.suppress, conversion.field_width, conversion.length); break;
				case 'x': result = parse_integer(BASE_TYPE<16>{}, IS_UNSIGNED<true> {}, conversion.suppress, conversion.field_width, conversion.length); break;
				case 'X': result = parse_integer(BASE_TYPE<16>{}, IS_UNSIGNED<true> {}, conversion.suppress, conversion.field_width, conversion.length); break;
				case 'p': result = parse_integer(BASE_TYPE<16>{}, IS_UNSIGNED<true> {}, conversion.suppress, conversion.field_width, LengthModifier::j); break;
#if __enable_sse
				case 'a': case 'e': case 'f': case 'g':
					result = parse_floating_point(conversion.suppress, conversion.field_width, conversion.length);
					break;
#endif
				case 'S':
					conversion.length = LengthModifier::l;
					// fall through
				case 's':
				{
					int width = conversion.field_width;
					if (width == -1)
						width = __INT_MAX__;
					uint8_t mask[0x100 / 8] {};
					for (int i = 0; i < 0x100; i++)
						if (isspace(i))
							mask[i / 8] |= 1 << (i % 8);
					get_input(false);
					while (isspace(in))
						get_input(true);
					result = parse_string(mask, true, conversion.suppress, conversion.allocate, 1, width, true);
					break;
				}
				case 'C':
					conversion.length = LengthModifier::l;
					// fall through
				case 'c':
				{
					int width = conversion.field_width;
					if (width == -1)
						width = 1;
					uint8_t mask[0x100 / 8] {};
					result = parse_string(mask, true, conversion.suppress, conversion.allocate, width, width, false);
					break;
				}
				case '[':
				{
					int width = conversion.field_width;
					if (width == -1)
						width = __INT_MAX__;

					bool exclude = (*format == '^');
					if (exclude)
						format++;

					uint8_t mask[0x100 / 8] {};
					if (*format == ']')
					{
						mask[']' / 8] |= 1 << (']' % 8);
						format++;
					}

					bool first = true;
					while (*format && *format != ']')
					{
						if (!first && *format == '-' && *(format + 1) != ']')
						{
							int min = BAN::Math::min(*(format - 1), *(format + 1));
							int max = BAN::Math::max(*(format - 1), *(format + 1));
							for (int i = min; i <= max; i++)
								mask[i / 8] |= 1 << (i % 8);
							format += 2;
						}
						else
						{
							mask[*format / 8] |= 1 << (*format % 8);
							format++;
						}
						first = false;
					}

					if (*format == ']')
						result = parse_string(mask, exclude, conversion.suppress, conversion.allocate, 1, width, true);
					else
						result = ConversionResult::MATCH_FAILURE;
					format++;
					break;
				}
				case 'n':
					if (!conversion.suppress)
						*va_arg(arguments, int*) = nread - (in != NONE);
					conversion.suppress = true; // Dont count this as conversion
					result = ConversionResult::SUCCESS;
					break;
				case '%':
					get_input(false);
					if (in == DONE)
					{
						result = ConversionResult::INPUT_FAILURE;
						break;
					}
					if (in != '%')
					{
						result = ConversionResult::MATCH_FAILURE;
						break;
					}
					get_input(true);
					result = ConversionResult::SUCCESS;
					break;
				default:
					result = ConversionResult::MATCH_FAILURE;
					break;
			}

			ASSERT(result != ConversionResult::NONE);

			if (result == ConversionResult::INPUT_FAILURE && conversions == 0)
				return EOF;
			if (result != ConversionResult::SUCCESS)
				return conversions;
			if (!conversion.suppress)
				conversions++;
		}
		else if (isspace(*format))
		{
			get_input(false);
			while (isspace(in))
				get_input(true);
			format++;
		}
		else
		{
			get_input(false);
			if (in != *format)
				break;
			get_input(true);
			format++;
		}
	}

	return conversions;
}