brot/bignum.lua

local ffi = require "ffi"
local bit = require "bit"

local M = {}

-- Arbitrary-precision integer stored as sign + magnitude.
--   sign:   -1, 0 or 1; normalize() sets it to 0 when the value is zero.
--   size:   number of limbs of `values` currently in use. It can be smaller
--           than the allocated length of `values` (normalize() shrinks it
--           without reallocating).
--   values: variable-length array of 32-bit limbs, least significant first
--           (M.int stores the low 32 bits in values[0]).
ffi.cdef [[
	typedef struct bigint {
		int8_t sign;
		uint32_t size;
		uint32_t values[?];
	} bigint;
]]

-- Metamethods for bigint cdata; attached with ffi.metatype further down.
local int_mt = {}

--- Drop high-order zero limbs in place, always keeping at least one limb.
-- When what remains is a single zero limb the sign is reset to 0, so zero has
-- a single canonical representation.
-- @param int bigint (modified in place)
-- @return the same bigint, for chaining
local function normalize(int)
	local top = int.size
	while top > 1 and int.values[top - 1] == 0 do
		top = top - 1
	end
	int.size = top

	if top == 1 and int.values[0] == 0 then
		int.sign = 0
	end
	return int
end

--- Add the magnitudes of two bigints, ignoring their signs.
-- The result has one more limb than the larger operand (room for a final
-- carry), always has sign 1, and is NOT normalized.
local function unsigned_add(a, b)
	local n = math.max(a.size, b.size) + 1
	local out = ffi.new("bigint", n)
	out.size = n
	out.sign = 1

	-- 64-bit accumulator: two limbs plus a carry cannot exceed 33 bits.
	local acc = 0ULL
	for k = 0, n - 1 do
		if k < a.size then acc = acc + a.values[k] end
		if k < b.size then acc = acc + b.values[k] end
		out.values[k] = acc         -- conversion keeps the low 32 bits
		acc = bit.rshift(acc, 32)   -- carry into the next limb
	end
	return out
end

--- Subtract the magnitude of `b` from the magnitude of `a`, ignoring signs.
-- Callers must guarantee |a| >= |b| (int_mt.__add checks this with
-- unsigned_cmp first): a borrow out of the top limb is discarded, and limbs
-- of `b` at or beyond a.size are never read.
-- The result has a.size limbs, sign 1, and is NOT normalized.
local function unsigned_sub(a, b)
	local result = ffi.new("bigint", a.size)
	-- Borrow from the previous limb: 0 or -1.
	local carry = 0LL
	for i = 0, a.size - 1 do
		local diff = carry + a.values[i]
		if i < b.size then
			diff = diff - b.values[i]
		end
		if diff < 0 then
			-- Borrow from the next limb to bring this one back into [0, 2^32).
			diff = diff + 2^32
			carry = -1LL
		else
			carry = 0LL
		end
		result.values[i] = diff
	end

	result.size = a.size
	result.sign = 1
	return result
end
-- The original author saw an undiagnosed misbehaviour when this function is
-- JIT-compiled, so it is forced to run in the interpreter.
-- NOTE(review): root cause unknown; confirm whether this is still needed.
require"jit".off(unsigned_sub)

--- Three-way comparison of the magnitudes of two bigints (signs ignored).
-- Assumes both are normalized, so a longer limb array means a larger value.
-- @return 1 if |a| > |b|, -1 if |a| < |b|, 0 if equal
local function unsigned_cmp(a, b)
	if a.size ~= b.size then
		return a.size > b.size and 1 or -1
	end

	-- Same length: the most significant differing limb decides.
	local i = a.size - 1
	while i >= 0 do
		local x, y = a.values[i], b.values[i]
		if x ~= y then
			return x > y and 1 or -1
		end
		i = i - 1
	end
	return 0
end

--- Three-way comparison of two bigints taking signs into account.
-- @return 1 if a > b, -1 if a < b, 0 if equal
local function signed_cmp(a, b)
	if a.sign ~= b.sign then
		return a.sign > b.sign and 1 or -1
	end

	-- Equal signs: compare magnitudes; for non-positive signs a larger
	-- magnitude means a smaller value, so flip the result.
	local magnitude = unsigned_cmp(a, b)
	if a.sign > 0 then
		return magnitude
	end
	return -magnitude
end

--- Shift a bigint's magnitude left by `n` bits (multiply by 2^n), keeping
-- its sign.
-- Returns `int` itself (not a copy) when `n` is 0 or `int` is zero;
-- otherwise returns a new normalized bigint.
-- @param int bigint
-- @param n non-negative integer number of bits
local function lshift(int, n)
	if n == 0 then return int end
	-- Zero has sign 0 (see normalize / M.int). Testing the field directly
	-- avoids `int == 0`, which goes through int_mt.__eq and allocates a
	-- temporary bigint via M.int(0) on every call.
	if int.sign == 0 then return int end
	assert(n > 0)

	local word_shift = math.floor(n / 32)
	local bit_shift = n % 32
	-- Each source limb i lands in limb i + word_shift, and its overflow goes
	-- into the next limb, so one extra limb is needed. The floor is explicit:
	-- it no longer relies on the FFI truncating a fractional size.
	local size = int.size + word_shift + 1
	local result = ffi.new("bigint", size)
	result.size = size

	for i = 0, int.size - 1 do
		-- Widen to 64 bits so the bits pushed past bit 31 are kept.
		local val = bit.lshift(ffi.cast("uint64_t", int.values[i]), bit_shift)
		local index = i + word_shift
		-- Merge with the overflow written here by the previous limb.
		result.values[index] = bit.bor(val, result.values[index])
		result.values[index + 1] = bit.rshift(val, 32)
	end
	result.sign = int.sign

	normalize(result)
	return result
end

--- bigint + bigint (either side may also be a plain Lua number).
-- Same signs add magnitudes. Mixed signs subtract the smaller magnitude from
-- the larger and take the sign of the larger.
function int_mt.__add(a, b)
	if not M.is_bigint(a) then a = M.int(a) end
	if not M.is_bigint(b) then b = M.int(b) end

	local a_neg, b_neg = a.sign < 0, b.sign < 0
	local result
	if a_neg == b_neg then
		result = unsigned_add(a, b)
		if a_neg then
			result.sign = -1
		end
	else
		local neg, pos
		if a_neg then
			neg, pos = a, b
		else
			neg, pos = b, a
		end
		if unsigned_cmp(pos, neg) > 0 then
			result = unsigned_sub(pos, neg)
		else
			-- |neg| >= |pos|; an exact tie becomes zero and normalize
			-- resets its sign to 0.
			result = unsigned_sub(neg, pos)
			result.sign = -1
		end
	end
	return normalize(result)
end

--- bigint - bigint, implemented as a + (-b); also accepts Lua numbers.
function int_mt.__sub(a, b)
	local negated = -b
	return a + negated
end

--- bigint * bigint (either side may also be a plain Lua number), using
-- schoolbook multiplication over 32-bit limbs.
-- @return new normalized bigint whose sign is a.sign * b.sign
function int_mt.__mul(a, b)
	a = M.is_bigint(a) and a or M.int(a)
	b = M.is_bigint(b) and b or M.int(b)

	-- ffi.new zero-fills the VLA, so partial products can be accumulated.
	local result = ffi.new("bigint", a.size + b.size + 1)
	result.size = a.size + b.size

	for i = 0, b.size - 1 do
		-- Widen the multiplier to uint64_t. Elements of a uint32_t array read
		-- back as Lua numbers (doubles), and a product of two 32-bit limbs
		-- can need 64 bits, beyond the 53-bit exact range of a double.
		local bi = ffi.cast("uint64_t", b.values[i])
		local carry = 0ULL
		for j = 0, a.size - 1 do
			-- (2^32-1) + (2^32-1) + (2^32-1)^2 == 2^64-1: cannot overflow.
			local val = carry + result.values[i + j] + bi * a.values[j]
			carry = bit.rshift(val, 32)
			result.values[i + j] = val -- conversion keeps the low 32 bits
		end
		-- Limb i + a.size has not been written by any earlier row.
		result.values[i + a.size] = carry
	end
	result.sign = a.sign * b.sign

	normalize(result)
	return result
end

--- Unary minus: a fresh copy of the limbs with the sign flipped
-- (zero keeps sign 0).
function int_mt.__unm(int)
	local n = int.size
	local negated = ffi.new("bigint", n)
	negated.sign = -int.sign
	negated.size = n
	for k = 0, n - 1 do
		negated.values[k] = int.values[k]
	end
	return negated
end

--- Hexadecimal rendering such as "0x1deadbeef" or "-0x5".
-- The most significant limb is printed without padding; every lower limb is
-- zero-padded to 8 hex digits.
function int_mt.__tostring(int)
	local parts = {}
	if int.sign < 0 then
		parts[#parts + 1] = "-"
	end
	parts[#parts + 1] = "0x"

	local top = int.size - 1
	for k = top, 0, -1 do
		local fmt = (k == top) and "%x" or "%08x"
		parts[#parts + 1] = fmt:format(int.values[k])
	end
	return table.concat(parts)
end

--- Value equality; a non-bigint operand is converted with M.int first.
function int_mt.__eq(a, b)
	if not M.is_bigint(a) then a = M.int(a) end
	if not M.is_bigint(b) then b = M.int(b) end
	local order = signed_cmp(a, b)
	return order == 0
end

--- Strict less-than; a non-bigint operand is converted with M.int first.
function int_mt.__lt(a, b)
	if not M.is_bigint(a) then a = M.int(a) end
	if not M.is_bigint(b) then b = M.int(b) end
	local order = signed_cmp(a, b)
	return order < 0
end

-- Attach int_mt's metamethods (+, -, *, unary -, ==, <, tostring) to the
-- bigint ctype.
ffi.metatype("bigint", int_mt)

--- Construct a bigint.
-- If `i` is already a bigint, returns an independent copy of it. Otherwise
-- `i` must be a finite Lua number: it is floored and split into 32-bit limbs,
-- least significant first.
-- @param i bigint or finite number
-- @return new normalized bigint
-- @raise error for NaN or +/-inf
function M.int(i)
	if M.is_bigint(i) then
		local int = ffi.new("bigint", i.size)
		int.size = i.size
		int.sign = i.sign
		for j = 0, i.size - 1 do
			int.values[j] = i.values[j]
		end
		return int
	end

	i = math.floor(i)
	-- NaN and +/-inf cannot be split into limbs.
	if i ~= i or i == math.huge or i == -math.huge then
		error("bigint: cannot convert non-finite number " .. tostring(i), 2)
	end

	if i == 0 then
		local int = ffi.new("bigint", 1) -- zero-filled by ffi.new
		int.sign = 0
		int.size = 1
		return int
	end

	local sign = 1
	if i < 0 then
		sign = -1
		i = -i
	end

	-- Count the limbs exactly. The old estimate floor(log2(|i|) / 32) (or 1
	-- when |i| <= 2^32) was one limb short for every |i| >= 2^32, so the
	-- fill loop wrote past the end of `values`.
	local size = 0
	local rest = i
	while rest > 0 do
		size = size + 1
		rest = math.floor(rest / 2^32)
	end

	local int = ffi.new("bigint", size)
	int.sign = sign
	int.size = size
	for k = 0, size - 1 do
		int.values[k] = i % 2^32
		i = math.floor(i / 2^32)
	end
	return int
end

--- True when `v` is a cdata of the bigint ctype; false otherwise.
function M.is_bigint(v)
	if type(v) ~= "cdata" then
		return false
	end
	return ffi.typeof(v) == ffi.typeof("bigint")
end

-- Metatable for bigfloat tables, whose value is mantissa * 2^exp (see
-- M.tonumber). Fields: `mantissa` (bigint), `exp` (binary exponent), and
-- `precision` (the most mantissa limbs truncate keeps; math.huge means no
-- limit).
local float_mt = {}

--- Limit a bigfloat's mantissa to at most `precision` limbs by dropping
-- low-order limbs and raising the exponent to compensate. Low zero limbs
-- above the cut are dropped as well.
-- A zero mantissa gets its exponent reset to 0 in place. A single-limb
-- mantissa is returned unchanged. A precision below 1 yields M.float(0).
-- @param f bigfloat
-- @param precision optional limb budget; defaults to f.precision
-- @return `f` itself when nothing is dropped, otherwise a new bigfloat
local function truncate(f, precision)
	precision = precision or f.precision
	local m = f.mantissa

	if m == 0 then
		f.exp = 0
		return f
	end
	if m.size == 1 then
		return f
	end
	if precision < 1 then
		return M.float(0)
	end

	-- First limb to keep; skip over zero limbs so no trailing zeros remain.
	local first = math.max(m.size - precision, 0)
	while m.values[first] == 0 do
		first = first + 1
	end
	if first == 0 then
		return f
	end

	local kept = m.size - first
	local new = M.float()
	new.exp = f.exp + first * 32
	new.precision = precision

	local nm = ffi.new("bigint", kept)
	nm.size = kept
	nm.sign = m.sign
	for k = 0, kept - 1 do
		nm.values[k] = m.values[first + k]
	end
	new.mantissa = nm
	return new
end

--- Bit length of a non-zero bigfloat's magnitude: the integer e with
-- 2^(e-1) <= |f| < 2^e. Expects f.mantissa to be normalized (non-zero top
-- limb), as produced by M.int and the int operators via normalize().
local function float_bit_length(f)
	local m = f.mantissa
	-- For a positive integer v, math.frexp(v) returns e with
	-- v = x * 2^e and 0.5 <= x < 1, i.e. e is the bit length of v.
	local _, top_bits = math.frexp(m.values[m.size - 1])
	return f.exp + 32 * (m.size - 1) + top_bits
end

--- Three-way comparison of two bigfloats.
-- @return 1 if a > b, -1 if a < b, 0 if equal
local function cmp_float(a, b)
	local sign = a.mantissa.sign
	local b_sign = b.mantissa.sign
	if sign ~= b_sign then
		return sign > b_sign and 1 or -1
	end
	if sign == 0 then
		return 0 -- both zero
	end

	-- Compare exact bit lengths. The previous estimate
	-- size + ceil(exp / 32) was not monotone: it ranked 2 (mantissa 1,
	-- exp 1) above 0xffffffff (mantissa 0xffffffff, exp 0).
	local a_len, b_len = float_bit_length(a), float_bit_length(b)
	if a_len ~= b_len then
		-- For negative numbers the larger magnitude is the smaller value.
		return (a_len > b_len and 1 or -1) * sign
	end

	-- Same bit length: decide by the sign of the difference.
	return (a - b).mantissa.sign
end

local recursion = true -- NOTE(review): not referenced anywhere in this file

--- bigfloat + bigfloat (either side may be anything M.float accepts).
-- Both operands are aligned to the smaller exponent. Each operand is first
-- truncated so the aligned result stays within the smaller precision.
function float_mt.__add(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end

	local sum = M.float()
	sum.precision = math.min(a.precision, b.precision)
	local base = math.min(a.exp, b.exp)
	sum.exp = base

	-- Shifting left by (exp - base) bits costs (exp - base) / 32 limbs of
	-- the precision budget.
	local ta = truncate(a, math.floor(sum.precision - (a.exp - base) / 32))
	local tb = truncate(b, math.floor(sum.precision - (b.exp - base) / 32))

	local shifted_a = lshift(ta.mantissa, ta.exp - base)
	local shifted_b = lshift(tb.mantissa, tb.exp - base)
	sum.mantissa = shifted_a + shifted_b

	return truncate(sum)
end

--- bigfloat - bigfloat, implemented as a + (-b).
function float_mt.__sub(a, b)
	local negated = -b
	return a + negated
end

--- bigfloat * bigfloat: multiply mantissas, add exponents, then truncate to
-- the smaller of the two precisions.
function float_mt.__mul(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end

	local product = M.float()
	product.precision = math.min(a.precision, b.precision)
	product.exp = a.exp + b.exp
	product.mantissa = a.mantissa * b.mantissa
	return truncate(product)
end

--- Unary minus: same exponent and precision, negated mantissa.
function float_mt.__unm(f)
	local negated = M.float()
	negated.mantissa = -f.mantissa
	negated.exp = f.exp
	negated.precision = f.precision
	return negated
end

--- Value equality; a non-bigfloat operand is converted with M.float first.
function float_mt.__eq(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end
	local order = cmp_float(a, b)
	return order == 0
end

--- Strict less-than; a non-bigfloat operand is converted with M.float first.
function float_mt.__lt(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end
	local order = cmp_float(a, b)
	return order < 0
end

--- Collect mantissa bits min..max (inclusive) into 4-bit hex digits.
-- Digit boundaries fall where (i - anchor) is a multiple of 4. `anchor`
-- defaults to `min`. Bit positions outside the mantissa read as 0.
-- @return array of digit values (0..15), least significant digit first
local function digits_in_range(f, min, max, anchor)
	anchor = anchor or min
	local limit = f.mantissa.size * 32

	local digits = {0}
	for i = min, max do
		local offset = (i - anchor) % 4
		if offset == 0 and i ~= min then
			digits[#digits + 1] = 0 -- start the next digit
		end

		local b = 0
		if i >= 0 and i < limit then
			local word = f.mantissa.values[math.floor(i / 32)]
			b = bit.band(bit.rshift(word, i % 32), 1)
		end
		digits[#digits] = bit.bor(digits[#digits], bit.lshift(b, offset))
	end
	return digits
end

--- Hexadecimal rendering with a binary point, e.g. "0x1a.8" or "-0x0.4".
-- Leading zeros of the integer part and trailing zeros of the fraction are
-- trimmed, but each side keeps at least one digit.
function float_mt.__tostring(f)
	local hex = "0123456789abcdef"
	-- Mantissa bit index that corresponds to 2^0 in the value.
	local point = -f.exp

	-- Integer part, least significant digit first; trim high zero digits.
	local int_digits = digits_in_range(f, point, f.mantissa.size * 32 - 1)
	while #int_digits > 1 and int_digits[#int_digits] == 0 do
		int_digits[#int_digits] = nil
	end

	-- Fraction, least significant digit first, with digit boundaries aligned
	-- to the binary point; trim the low (trailing) zero digits.
	local frac_digits = digits_in_range(f, 0, point - 1, point)
	while #frac_digits > 1 and frac_digits[1] == 0 do
		table.remove(frac_digits, 1)
	end

	local out = {}
	if f.mantissa < 0 then
		out[#out + 1] = "-"
	end
	out[#out + 1] = "0x"
	for k = #int_digits, 1, -1 do
		local d = int_digits[k]
		out[#out + 1] = hex:sub(d + 1, d + 1)
	end
	out[#out + 1] = "."
	for k = #frac_digits, 1, -1 do
		local d = frac_digits[k]
		out[#out + 1] = hex:sub(d + 1, d + 1)
	end
	return table.concat(out)
end

--- Construct a bigfloat (a table with float_mt).
-- Accepts:
--   nil or 0     -> zero (fresh zero mantissa, exponent 0)
--   a bigfloat   -> shallow copy sharing its mantissa and exponent
--   a bigint     -> that integer (mantissa shared, exponent 0, not truncated)
--   a Lua number -> exact conversion of the double, then truncate()d
-- @param f nil, number, bigint or bigfloat
-- @param precision optional limb budget used by truncate; defaults to
--   math.huge. A copied bigfloat does not inherit its source's precision.
function M.float(f, precision)
	local new = setmetatable({}, float_mt)
	new.precision = precision or math.huge

	-- Check the cdata/table kinds before comparing with nil/0. For a bigint,
	-- `f == nil` dispatches to int_mt.__eq(f, nil), which calls M.int(nil)
	-- and raises. The old order `f == 0 or f == nil` therefore failed for
	-- every non-zero bigint argument.
	if M.is_bigfloat(f) then
		new.mantissa = f.mantissa
		new.exp = f.exp
		return new
	elseif M.is_bigint(f) then
		new.mantissa = f
		new.exp = 0
		return new
	elseif f == nil or f == 0 then
		new.mantissa = M.int(0)
		new.exp = 0
		return new
	end

	-- frexp gives f = m * 2^e with 0.5 <= |m| < 1, so m * 2^53 is an integer
	-- holding every significand bit of the double exactly.
	local m, e = math.frexp(f)
	new.mantissa, new.exp = M.int(m * 2^53), e - 53
	return truncate(new)
end

--- True when `v` carries the bigfloat metatable.
function M.is_bigfloat(v)
	local mt = getmetatable(v)
	return mt == float_mt
end

--- Convert a number, bigint or bigfloat to a (possibly rounded) Lua number.
-- Limbs are summed from least to most significant, each scaled by its power
-- of two. Any other kind of value yields nil.
function M.tonumber(n)
	if type(n) == "number" then
		return n
	end

	local limbs, scale, sign
	if M.is_bigint(n) then
		limbs, scale, sign = n, 0, n.sign
	elseif M.is_bigfloat(n) then
		limbs, scale, sign = n.mantissa, n.exp, n.mantissa.sign
	else
		return nil
	end

	local total = 0
	for k = 0, limbs.size - 1 do
		total = total + limbs.values[k] * 2^(32 * k + scale)
	end
	return total * sign
end

-- Public API: M.int, M.is_bigint, M.float, M.is_bigfloat, M.tonumber.
return M