#include <cassert>
#include <cstring>

#include "bus.h"  // for (at least) ADDR_PSW
#include "gen.h"
#include "log.h"
#include "mmu.h"
#include "utils.h"


// Construction does no work; begin() attaches the memory backend and reset() clears
// the PAR/PDR tables and the registers.
mmu::mmu() = default;

// The MMU does not own the memory object handed to begin(), so nothing is released here.
mmu::~mmu() = default;

/*
 * Attach the memory backend and bring the MMU into its power-on state.
 * The memory object is consulted during address translation to detect
 * accesses beyond the installed memory size.
 */
void mmu::begin(memory *const m)
{
	this->m = m;
	reset();
}

/*
 * Clear every PAR/PDR entry (all run modes, both I and D space) and zero the
 * MMU-related registers.
 */
void mmu::reset()
{
	memset(pages, 0x00, sizeof pages);

	CPUERR = 0;
	MMR0   = 0;
	MMR1   = 0;
	MMR2   = 0;
	MMR3   = 0;
	PIR    = 0;
	CSR    = 0;
}

uint16_t mmu::read_pdr(const uint32_t a, const int run_mode)
{
	int      page = (a >> 1) & 7;
	bool     is_d = a & 16;
	uint16_t t    = pages[run_mode][is_d][page].pdr;

	return t;
}

uint16_t mmu::read_par(const uint32_t a, const int run_mode)
{
	int      page = (a >> 1) & 7;
	bool     is_d = a & 16;
	uint16_t t    = pages[run_mode][is_d][page].par;

	return t;
}

/*
 * Store a new MMR0 value.
 *
 * - Bits 10 and 11 are always stored as 0.
 * - If the new value has bit 0 (enable) set, the error bits 15..13 are cleared.
 * - Otherwise, if the current MMR0 has any of bits 15..13 set, only bits 7..1
 *   of the new value are kept (bits 15..8 and bit 0 are stored as 0).
 */
void mmu::setMMR0(uint16_t value)
{
	value &= ~(3 << 10);

	const bool enable = value & 1;

	if (enable)
		value &= ~(7l << 13);
	else if (MMR0 & 0160000)
		value &= 254;

	MMR0 = value;
}

/*
 * Set a single MMR0 bit (0-15). Bits 10 and 11 are rejected because they
 * always read as zero.
 */
void mmu::setMMR0Bit(const int bit)
{
	assert(bit != 10 && bit != 11);
	assert(bit < 16 && bit >= 0);

	const int mask = 1 << bit;
	MMR0 |= mask;
}

/*
 * Clear a single MMR0 bit (0-15). Bits 10 and 11 are rejected because they
 * always read as zero.
 */
void mmu::clearMMR0Bit(const int bit)
{
	assert(bit != 10 && bit != 11);
	assert(bit < 16 && bit >= 0);

	const int mask = 1 << bit;
	MMR0 &= ~mask;
}

// Store MMR2 unchanged; no bits are masked.
void mmu::setMMR2(const uint16_t value)
{
	MMR2 = value;
}

// Store MMR3 unchanged. Bit 4 selects 22-bit addressing and bits 2..0 enable
// D space per run mode during translation.
void mmu::setMMR3(const uint16_t value)
{
	MMR3 = value;
}

/*
 * Report whether MMR3 enables separate D space for the given run mode:
 * bit 2 for run mode 0 (kernel), bit 1 for run mode 1 (supervisor) and
 * bit 0 for run mode 3 (user). Run mode 2 never uses D space.
 */
bool mmu::get_use_data_space(const int run_mode) const
{
	switch(run_mode) {
		case 0:
			return MMR3 & 4;
		case 1:
			return MMR3 & 2;
		case 3:
			return MMR3 & 1;
		default:
			return false;
	}
}

// Discard all register-modification entries recorded in MMR1.
void mmu::clearMMR1()
{
	MMR1 = 0;
}

/*
 * Record an auto-increment/decrement in MMR1.
 *
 * The previous entry is shifted into the upper byte. The new entry occupies
 * the lower byte: bits 7..3 hold the delta as a 5-bit two's-complement value,
 * and bits 2..0 hold the register number.
 */
void mmu::addToMMR1(const int8_t delta, const uint8_t reg)
{
	assert(reg >= 0 && reg <= 7);
	assert(delta >= -2 && delta <= 2);

	assert((getMMR0() & 0160000) == 0);  // MMR1 should not be locked

	const uint16_t entry = ((delta & 31) << 3) | reg;

	MMR1 = (MMR1 << 8) | entry;
}

/*
 * Write a page descriptor register through its I/O address.
 *
 * Address bits 3..1 select the page and bit 4 the D-space set. A byte write
 * replaces only the byte selected by address bit 0. After the store, bits 15,
 * 7 (A), 6 (W), 5 and 4 are cleared: 15/5/4 are unused, and A/W are reset by
 * any write.
 */
void mmu::write_pdr(const uint32_t a, const int run_mode, const uint16_t value, const word_mode_t word_mode)
{
	const bool is_d  = a & 16;
	const int  page  = (a >> 1) & 7;
	auto      &entry = pages[run_mode][is_d][page];

	if (word_mode != wm_byte) {
		entry.pdr = value;
	}
	else {
		assert(a != 0 || value < 256);

		update_word(&entry.pdr, a & 1, value);
	}

	entry.pdr &= ~(32768 + 128 /*A*/ + 64 /*W*/ + 32 + 16);

	TRACE("mmu WRITE-I/O PDR run-mode %d: %c for %d: %o [%d]", run_mode, is_d ? 'D' : 'I', page, value, word_mode);
}

/*
 * Write a page address register through its I/O address.
 *
 * Address bits 3..1 select the page and bit 4 the D-space set. A byte write
 * replaces only the byte selected by address bit 0. Any PAR write also clears
 * the A (bit 7) and W (bit 6) flags of the matching PDR.
 */
void mmu::write_par(const uint32_t a, const int run_mode, const uint16_t value, const word_mode_t word_mode)
{
	const bool is_d  = a & 16;
	const int  page  = (a >> 1) & 7;
	auto      &entry = pages[run_mode][is_d][page];

	if (word_mode != wm_byte)
		entry.par = value;
	else
		update_word(&entry.par, a & 1, value);

	entry.pdr &= ~(128 /*A*/ + 64 /*W*/);

	TRACE("mmu WRITE-I/O PAR run-mode %d: %c for %d: %o (%07o)", run_mode, is_d ? 'D' : 'I', page, word_mode == wm_byte ? value & 0xff : value, entry.par * 64);
}

/*
 * Read a PAR or PDR register through the I/O page.
 * Run modes: 1 = supervisor, 0 = kernel, 3 = user.
 * Addresses outside all PAR/PDR ranges read as 0.
 */
uint16_t mmu::read_word(const uint16_t a)
{
	const auto in = [a](const uint32_t start, const uint32_t end) { return a >= start && a < end; };

	if (in(ADDR_PDR_SV_START, ADDR_PDR_SV_END))
		return read_pdr(a, 1);
	if (in(ADDR_PAR_SV_START, ADDR_PAR_SV_END))
		return read_par(a, 1);
	if (in(ADDR_PDR_K_START, ADDR_PDR_K_END))
		return read_pdr(a, 0);
	if (in(ADDR_PAR_K_START, ADDR_PAR_K_END))
		return read_par(a, 0);
	if (in(ADDR_PDR_U_START, ADDR_PDR_U_END))
		return read_pdr(a, 3);
	if (in(ADDR_PAR_U_START, ADDR_PAR_U_END))
		return read_par(a, 3);

	return 0;
}

/*
 * Read one byte of a PAR/PDR register: the high byte for an odd address and
 * the low byte for an even one.
 */
uint8_t mmu::read_byte(const uint16_t addr)
{
	const uint16_t word = read_word(addr);

	return static_cast<uint8_t>((addr & 1) ? word >> 8 : word);
}

/*
 * Write a full word to a PAR or PDR register through the I/O page.
 * Run modes: 1 = supervisor, 0 = kernel, 3 = user.
 * Addresses outside all PAR/PDR ranges are ignored.
 */
void mmu::write_word(const uint16_t a, const uint16_t value)
{
	const auto in = [a](const uint32_t start, const uint32_t end) { return a >= start && a < end; };

	if (in(ADDR_PDR_SV_START, ADDR_PDR_SV_END)) {
		write_pdr(a, 1, value, wm_word);
		return;
	}
	if (in(ADDR_PAR_SV_START, ADDR_PAR_SV_END)) {
		write_par(a, 1, value, wm_word);
		return;
	}
	if (in(ADDR_PDR_K_START, ADDR_PDR_K_END)) {
		write_pdr(a, 0, value, wm_word);
		return;
	}
	if (in(ADDR_PAR_K_START, ADDR_PAR_K_END)) {
		write_par(a, 0, value, wm_word);
		return;
	}
	if (in(ADDR_PDR_U_START, ADDR_PDR_U_END)) {
		write_pdr(a, 3, value, wm_word);
		return;
	}
	if (in(ADDR_PAR_U_START, ADDR_PAR_U_END))
		write_par(a, 3, value, wm_word);
}

/*
 * Write one byte to a PAR or PDR register through the I/O page. Address bit 0
 * selects which half of the register is replaced.
 * Run modes: 1 = supervisor, 0 = kernel, 3 = user.
 * Addresses outside all PAR/PDR ranges are ignored.
 */
void mmu::write_byte(const uint16_t a, const uint8_t value)
{
	const auto in = [a](const uint32_t start, const uint32_t end) { return a >= start && a < end; };

	if (in(ADDR_PDR_SV_START, ADDR_PDR_SV_END)) {
		write_pdr(a, 1, value, wm_byte);
		return;
	}
	if (in(ADDR_PAR_SV_START, ADDR_PAR_SV_END)) {
		write_par(a, 1, value, wm_byte);
		return;
	}
	if (in(ADDR_PDR_K_START, ADDR_PDR_K_END)) {
		write_pdr(a, 0, value, wm_byte);
		return;
	}
	if (in(ADDR_PAR_K_START, ADDR_PAR_K_END)) {
		write_par(a, 0, value, wm_byte);
		return;
	}
	if (in(ADDR_PDR_U_START, ADDR_PDR_U_END)) {
		write_pdr(a, 3, value, wm_byte);
		return;
	}
	if (in(ADDR_PAR_U_START, ADDR_PAR_U_END))
		write_par(a, 3, value, wm_byte);
}

/*
 * Record the page of virtual address `a` in MMR0 bits 3..1. For a write, the
 * page is also marked via set_page_trapped(). Note: this function does not
 * check whether `a` is odd itself; presumably the caller already did.
 */
void mmu::trap_if_odd(const uint16_t a, const int run_mode, const d_i_space_t space, const bool is_write)
{
	const int apf = a >> 13;

	if (is_write)
		set_page_trapped(run_mode, space == d_space, apf);

	MMR0 = (MMR0 & ~(7 << 1)) | (apf << 1);
}

/*
 * Side-effect-free translation of virtual address `a` for `run_mode`, returning
 * both the I-space and the D-space physical address, plus whether each one is
 * the PSW.
 *
 * With the MMU disabled, `a` is returned unchanged for both spaces. Otherwise
 * each space is relocated through its PAR. The result is masked to 18 bits
 * unless MMR3 bit 4 is set. When D space is not enabled for the run mode, the
 * I-space address is used for data as well.
 */
memory_addresses_t mmu::calculate_physical_address(const int run_mode, const uint16_t a) const
{
	const uint8_t apf = a >> 13;  // active page field

	if (!is_enabled()) {
		const bool psw = a == ADDR_PSW;

		return { a, apf, a, psw, a, psw };
	}

	const bool     addr_22bit  = getMMR3() & 16;
	const uint16_t page_offset = a & 8191;

	auto relocate = [&](const int space) -> uint32_t {
		uint32_t physical = get_physical_memory_offset(run_mode, space, apf);
		physical += page_offset;

		return addr_22bit ? physical : physical & 0x3ffff;
	};

	const uint32_t phys_i      = relocate(0);
	const uint32_t phys_d_page = relocate(1);
	const uint32_t phys_d      = get_use_data_space(run_mode) ? phys_d_page : phys_i;

	const uint32_t io_base = get_io_base();
	auto is_psw = [io_base](const uint32_t physical) { return (physical - io_base + 0160000) == ADDR_PSW; };

	return { a, apf, phys_i, is_psw(phys_i), phys_d, is_psw(phys_d) };
}

std::pair<trap_action_t, int> mmu::get_trap_action(const int run_mode, const bool d, const int apf, const bool is_write)
{
	const int     access_control = get_access_control(run_mode, d, apf);

	trap_action_t trap_action    = T_PROCEED;

	if (access_control == 0)
		trap_action = T_ABORT_4;
	else if (access_control == 1)
		trap_action = is_write ? T_ABORT_4 : T_TRAP_250;
	else if (access_control == 2) {
		if (is_write)
			trap_action = T_ABORT_4;
	}
	else if (access_control == 3)
		trap_action = T_ABORT_4;
	else if (access_control == 4)
		trap_action = T_TRAP_250;
	else if (access_control == 5) {
		if (is_write)
			trap_action = T_TRAP_250;
	}
	else if (access_control == 6) {
		// proceed
	}
	else if (access_control == 7) {
		trap_action = T_ABORT_4;
	}

	return { trap_action, access_control };
}

// Trace how virtual address `a` translates in each of the four run modes.
void mmu::mmudebug(const uint16_t a)
{
	for(int run_mode = 0; run_mode < 4; ++run_mode) {
		const auto addrs = calculate_physical_address(run_mode, a);

		TRACE("RM %d, a: %06o, apf: %d, PI: %08o (PSW: %d), PD: %08o (PSW: %d)", run_mode, addrs.virtual_address, addrs.apf, addrs.physical_instruction, addrs.physical_instruction_is_psw, addrs.physical_data, addrs.physical_data_is_psw);
	}
}

/*
 * Translate virtual address `a` for `run_mode` into a physical address.
 *
 * Relocation is applied when the MMU is enabled, or for a write while MMR0
 * bit 8 (maintenance) is set. In all other cases `a` is returned unchanged.
 * D space is used only when `space` is d_space and MMR3 enables it for the
 * run mode. The result is masked to 18 bits unless MMR3 bit 4 is set.
 *
 * With trap_on_failure set, three checks are made in order:
 *   - the page's access-control field (throws 5),
 *   - the physical address against the installed memory size, with the I/O
 *     page excluded (throws 6),
 *   - the page length (throws 7).
 * On a failure, MMR0 is updated (unless it is locked), the page is marked for
 * writes, and c->trap() is called before the throw.
 *
 * The fault status is stored into MMR0 directly rather than through setMMR0().
 * setMMR0() clears bits 15..13 whenever bit 0 (enable) is set, and `temp`
 * always carries the current enable bit. Going through setMMR0() would
 * therefore erase the abort flags just set, and MMR0 could never lock.
 */
uint32_t mmu::calculate_physical_address(cpu *const c, const int run_mode, const uint16_t a, const bool trap_on_failure, const bool is_write, const bool peek_only, const d_i_space_t space)
{
	uint32_t m_offset = a;

	if (is_enabled() || (is_write && (getMMR0() & (1 << 8 /* maintenance check */)))) {
		uint8_t  apf      = a >> 13; // active page field

		bool     d        = space == d_space && get_use_data_space(run_mode);

		uint16_t p_offset = a & 8191;  // page offset

		m_offset  = get_physical_memory_offset(run_mode, d, apf);

		m_offset += p_offset;

		if ((getMMR3() & 16) == 0)  // off is 18bit
			m_offset &= 0x3ffff;

		uint32_t io_base  = get_io_base();
		bool     is_io    = m_offset >= io_base;

		if (trap_on_failure) {
			// 1. access-control field of the page
			auto [trap_action, access_control] = get_trap_action(run_mode, d, apf, is_write);

			if (trap_action != T_PROCEED) [[unlikely]] {
				if (is_write)
					set_page_trapped(run_mode, d, apf);

				if (is_locked() == false) {
					uint16_t temp = getMMR0();

					temp &= ~((1l << 15) | (1 << 14) | (1 << 13) | (1 << 12) | (3 << 5) | (7 << 1) | (1 << 4));

					if (is_write && access_control != 6)
						temp |= 1 << 13;  // read-only

					if (access_control == 0 || access_control == 4)
						temp |= 1l << 15;  // not resident
					else
						temp |= 1 << 13;  // read-only

					temp |= run_mode << 5;  // TODO: kernel-mode or user-mode when a trap occurs in user-mode?

					temp |= apf << 1; // add current page

					temp |= d << 4;

					MMR0 = temp;  // not setMMR0(): that would clear the abort bits while bit 0 is set

					TRACE("MMR0: %06o", temp);
				}

				if (trap_action == T_TRAP_250) {
					TRACE("Page access %d (for virtual address %06o): trap 0250", access_control, a);

					c->trap(0250);  // trap

					throw 5;
				}
				else {  // T_ABORT_4
					TRACE("Page access %d (for virtual address %06o): trap 004", access_control, a);

					c->trap(004);  // abort

					throw 5;
				}
			}

			// 2. physical address beyond installed memory (the I/O page is always valid)
			if (m_offset >= m->get_memory_size() && !is_io) [[unlikely]] {
				if (!peek_only)
					TRACE("mmu::calculate_physical_address %o >= %o", m_offset, m->get_memory_size());
				TRACE("TRAP(04) (throw 6) on address %06o", a);

				if (is_locked() == false) {
					uint16_t temp = getMMR0();

					temp &= 017777;
					temp |= 1l << 15;  // non-resident

					temp &= ~14;  // add current page
					temp |= apf << 1;

					temp &= ~(3 << 5);
					temp |= run_mode << 5;

					MMR0 = temp;  // not setMMR0(): that would clear the abort bits while bit 0 is set
				}

				if (is_write)
					set_page_trapped(run_mode, d, apf);

				c->trap(04);

				throw 6;
			}

			// 3. page length: block number versus the PDR length field and direction
			uint16_t pdr_len = get_pdr_len(run_mode, d, apf);
			uint16_t pdr_cmp = (a >> 6) & 127;

			bool direction = get_pdr_direction(run_mode, d, apf);

			// TRACE("p_offset %06o pdr_len %06o direction %d, run_mode %d, apf %d, pdr: %06o", p_offset, pdr_len, direction, run_mode, apf, pages[run_mode][d][apf].pdr);

			if ((pdr_cmp > pdr_len && direction == false) || (pdr_cmp < pdr_len && direction == true)) [[unlikely]] {
				TRACE("mmu::calculate_physical_address::p_offset %o versus %o direction %d", pdr_cmp, pdr_len, direction);
				TRACE("TRAP(0250) (throw 7) on address %06o", a);
				c->trap(0250);  // invalid access

				if (is_locked() == false) {
					uint16_t temp = getMMR0();

					temp &= 017777;
					temp |= 1 << 14;  // length

					temp &= ~14;  // add current page
					temp |= apf << 1;

					temp &= ~(3 << 5);
					temp |= run_mode << 5;

					temp &= ~(1 << 4);
					temp |= d << 4;

					MMR0 = temp;  // not setMMR0(): that would clear the abort bits while bit 0 is set
				}

				if (is_write)
					set_page_trapped(run_mode, d, apf);

				throw 7;
			}
		}

		TRACE("virtual address %06o maps to physical address %08o (run_mode: %d, apf: %d, par: %08o, poff: %o, AC: %d, %s)", a, m_offset, run_mode, apf,
				get_physical_memory_offset(run_mode, d, apf),
				p_offset, get_access_control(run_mode, d, apf), d ? "D" : "I");
	}
	else {
		// TRACE("no MMU (read physical address %08o)", m_offset);
	}

	return m_offset;
}

#if IS_POSIX
/*
 * Store the eight PAR and PDR values of one run mode / address space in
 * `target` under key `name`, as { "par": [...8], "pdr": [...8] }.
 *
 * Every value created here is handed over with the *_new jansson calls, which
 * take over the caller's reference. The plain json_array_append() and
 * json_object_set() add a second reference instead, so each created integer,
 * array and sub-object would leak when `target` is released.
 */
void mmu::add_par_pdr(json_t *const target, const int run_mode, const bool is_d, const std::string & name) const
{
	json_t *j = json_object();

	json_t *ja_par = json_array();
	for(int i=0; i<8; i++)
		json_array_append_new(ja_par, json_integer(pages[run_mode][is_d][i].par));
	json_object_set_new(j, "par", ja_par);

	json_t *ja_pdr = json_array();
	for(int i=0; i<8; i++)
		json_array_append_new(ja_pdr, json_integer(pages[run_mode][is_d][i].pdr));
	json_object_set_new(j, "pdr", ja_pdr);

	json_object_set_new(target, name.c_str(), j);
}

/*
 * Serialize the MMU state to a new JSON object, owned by the caller.
 *
 * PAR/PDR tables are stored per run mode (0, 1 and 3; run mode 2 is skipped,
 * matching deserialize()) and per space under "runmode_<m>_d_<0|1>". The
 * MMR0-3, CPUERR, PIR and CSR registers are stored as integers.
 *
 * json_object_set_new() takes over the reference returned by json_integer().
 * The plain json_object_set() would leave an extra reference that is never
 * released.
 */
json_t *mmu::serialize() const
{
	json_t *j = json_object();

	for(int run_mode=0; run_mode<4; run_mode++) {
		if (run_mode == 2)
			continue;

		for(int is_d=0; is_d<2; is_d++)
			add_par_pdr(j, run_mode, is_d, format("runmode_%d_d_%d", run_mode, is_d));
	}

	json_object_set_new(j, "MMR0", json_integer(MMR0));
	json_object_set_new(j, "MMR1", json_integer(MMR1));
	json_object_set_new(j, "MMR2", json_integer(MMR2));
	json_object_set_new(j, "MMR3", json_integer(MMR3));
	json_object_set_new(j, "CPUERR", json_integer(CPUERR));
	json_object_set_new(j, "PIR", json_integer(PIR));
	json_object_set_new(j, "CSR", json_integer(CSR));

	return j;
}

void mmu::set_par_pdr(const json_t *const j_in, const int run_mode, const bool is_d, const std::string & name)
{
	json_t *j = json_object_get(j_in, name.c_str());

	json_t *j_par = json_object_get(j, "par");
	for(int i=0; i<8; i++)
		pages[run_mode][is_d][i].par = json_integer_value(json_array_get(j_par, i));
	json_t *j_pdr = json_object_get(j, "pdr");
	for(int i=0; i<8; i++)
		pages[run_mode][is_d][i].pdr = json_integer_value(json_array_get(j_pdr, i));
}

/*
 * Build a new MMU from JSON produced by serialize(), attached to `mem`.
 * The caller owns the returned object.
 */
mmu *mmu::deserialize(const json_t *const j, memory *const mem)
{
	mmu *const out = new mmu();
	out->begin(mem);

	constexpr int stored_run_modes[] = { 0, 1, 3 };  // run mode 2 is not serialized

	for(const int run_mode : stored_run_modes) {
		for(int is_d=0; is_d<2; is_d++)
			out->set_par_pdr(j, run_mode, is_d, format("runmode_%d_d_%d", run_mode, is_d));
	}

	auto reg = [j](const char *const key) { return json_integer_value(json_object_get(j, key)); };

	out->MMR0   = reg("MMR0");
	out->MMR1   = reg("MMR1");
	out->MMR2   = reg("MMR2");
	out->MMR3   = reg("MMR3");
	out->CPUERR = reg("CPUERR");
	out->PIR    = reg("PIR");
	out->CSR    = reg("CSR");

	return out;
}
#endif