#include "circuit_simulation.hpp"

/// Creates an empty simulation with no nets and no components.
/// All member state comes from whatever default member initializers the
/// class definition declares; nothing extra is done here.
CircuitSimulation::CircuitSimulation() = default;

/// Hands out the index of the next extra unknown in the linear system.
///
/// Extra unknowns are numbered consecutively right after the first
/// `nonGNDNodes` slots (the node-voltage unknowns). solve() resets the
/// counter before stamping and sizes the system for one extra unknown per
/// voltage source, so callers are presumably voltage-source stamps (confirm).
/// @return Row/column index in A (and entry in b / x) for the new unknown.
int CircuitSimulation::allocateExtraVariable() {
	const int slot = nonGNDNodes + extraVariableCount;
	extraVariableCount += 1;
	return slot;
}

/// Discards the whole circuit and marks the simulation as not validated.
/// Nets are released before components, so the teardown order is unchanged.
void CircuitSimulation::clear() {
	validated = false;
	nets.clear();
	components.clear();
}

void CircuitSimulation::validate(){
	validated = false;
	validatedText = "";
	
	bool fail = false;
	
	for(size_t i=0;i<components.size();i++){
		std::vector<Port*> t = components[i]->getPorts();
		for(size_t j = 0;j<t.size();j++){
			if(t[j]->net == nullptr){
				validatedText+=std::string("Element named ")+components[i]->getName()+ " has an invalid connection at port "+std::to_string(j)+".\n";
				fail=true;
			} else {
				Net* n1 = t[j]->net;
				for(size_t k = 0;k<n1->ports.size();k++){
					if(!n1->ports[k]){
						validatedText+=std::string("Element's net named ")+components[i]->getName()+ " has an invalid connection at port "+std::to_string(j)+" With net's port number "+std::to_string(k)+".\n";
						fail=true;
					}
				}
			}
		}
	}
	
	if(fail == false){
		validatedText= "No violations found in the simulation!\n";
		validated = true;
	}
	
}

void CircuitSimulation::solve(){
	
	validate();
	if(validated == false){
		return;
	}
	
	extraVariableCount=0;
	voltageSourceCount = 0;
	for(size_t i=0;i<components.size();i++){
		if(components[i]->isVoltageSource() == true){
			voltageSourceCount++;
		}
	}
	
	nonGNDNodes = nodeCount(); //Also GND nodes, because of the way they are modelled.
	
	int size = nonGNDNodes + voltageSourceCount ;
	A.resize(size, size);
	b.resize(size);
	x.resize(size);
	
	A.setZero();
	b.setZero();
	x.setZero();

	for(size_t i = 0;i<components.size();i++){
		if(components[i]->isGND() == false){
			
			components[i]->stamp(*this);
			
		}
	}
	
	x = A.colPivHouseholderQr().solve(b);

	// assign node voltages
	for (auto& net : nets) {
		if (!net->isGND()) {
			int nodeIdx = getNodeIndex(net.get()); // use the same mapping as in stamping
			net->voltage = x(nodeIdx);
		} else {
			net->voltage = 0.0;
		}
	}
	
	for(auto& comp: components){
		comp->updateCurrentAndPower(*this);
		comp->setSimulatedString(); //Set cached string for performance.
	}
	
}

/// Destroys the simulation. Owned nets and components are released by their
/// containers' own destructors; no extra cleanup is performed here.
CircuitSimulation::~CircuitSimulation() = default;

/// Depth-first traversal of the component graph starting at components[i].
///
/// Two components are adjacent when a port of one shares a net with a port
/// of the other. The traversal itself is getNextDepth(); this entry point
/// only seeds it with the start component.
///
/// @param i Index into `components` of the start component.
/// @return Every component reachable from the start, in the order it was
///         visited (the start component first). Empty if `i` is out of range
///         or the slot holds no component.
std::vector<ElectricalComponent*> CircuitSimulation::DFS(size_t i){
	std::vector<ElectricalComponent*> explored;
	if (i >= components.size()) {
		return explored;
	}

	ElectricalComponent* start = components[i].get();
	if (start == nullptr) {
		return explored;
	}

	// getNextDepth() performs exactly this visit (record, then recurse into
	// every unexplored neighbour), so delegate rather than duplicate it.
	getNextDepth(start, explored);
	return explored;
}

/// Recursive step of DFS().
///
/// Appends `c` to `explored`, then recurses into every not-yet-explored
/// component that owns a port on the same net as one of c's ports.
/// Unconnected ports (null net), null port entries and ports without an
/// owner are skipped instead of being dereferenced.
///
/// NOTE(review): recursion depth grows with the size of a connected group,
/// so very large circuits could exhaust the stack.
void CircuitSimulation::getNextDepth(ElectricalComponent* c, std::vector<ElectricalComponent*>& explored){
	if (c == nullptr) {
		return;
	}
	explored.push_back(c);

	for (Port* port : c->getPorts()) {
		if (port == nullptr || port->net == nullptr) {
			continue; // unconnected port: no neighbours through it
		}

		for (Port* peer : port->net->ports) {
			if (peer == nullptr) {
				continue;
			}
			ElectricalComponent* owner = peer->getOwner();
			if (owner != nullptr && !isContainingId(owner->getId(), explored)) {
				getNextDepth(owner, explored);
			}
		}
	}
}


/// Removes the wire between two ports.
///
/// Both nodes must be Ports that currently share the same net; otherwise
/// nothing happens. The matching serialized connection records are erased
/// from both owners' serializedRef, in both directions, when both owners and
/// both refs exist with ids >= 0. Both ports are then removed from the net
/// and left unconnected (net == nullptr).
///
/// NOTE(review): p1 and p2 are detached from the whole net. Any other wires
/// they had on that net are therefore severed in the live model, while those
/// wires' serialized records are kept. The net object also stays in `nets`
/// even if it ends up empty. Confirm both behaviours are intended.
void CircuitSimulation::deleteConnection(Node* a, Node* b){
	auto* p1 = dynamic_cast<Port*>(a);
	auto* p2 = dynamic_cast<Port*>(b);

	if (!p1 || !p2) return;
	if (!p1->net || p1->net != p2->net) return;

	Net* oldNet = p1->net;

	// Erases every record on `ref` describing the wire
	// portIdx -> (otherId, otherPort).
	auto removeSerialized = [](CustomElectricalComponent* ref,
	                           int portIdx,
	                           int otherId,
	                           int otherPort)
	{
		ref->connections.erase(
			std::remove_if(ref->connections.begin(), ref->connections.end(),
				[&](const auto& c) {
					return c.portIndex == portIdx &&
					       c.otherComponentId == otherId &&
					       c.otherPortIndex == otherPort;
				}),
			ref->connections.end());
	};

	auto* o1 = p1->getOwner();
	auto* o2 = p2->getOwner();

	// Either owner may be null; both must be checked before reading
	// serializedRef.
	if (o1 && o2) {
		auto* r1 = o1->serializedRef;
		auto* r2 = o2->serializedRef;

		if (r1 && r2 && r1->id >= 0 && r2->id >= 0) {
			removeSerialized(r1, static_cast<int>(p1->index), static_cast<int>(r2->id), static_cast<int>(p2->index));
			removeSerialized(r2, static_cast<int>(p2->index), static_cast<int>(r1->id), static_cast<int>(p1->index));
		}
	}

	// Detach p1 and p2 from the net.
	auto& ports = oldNet->ports;
	ports.erase(std::remove(ports.begin(), ports.end(), p1), ports.end());
	ports.erase(std::remove(ports.begin(), ports.end(), p2), ports.end());

	p1->net = nullptr;
	p2->net = nullptr;
}

/// Connects two ports, merging or creating nets as needed, and records the
/// wire in both owners' serialized connection lists (each direction once).
///
/// Cases:
///  - both ports already on the same net: nothing happens, and no record is
///    added;
///  - ports on different nets: every port of b's net moves into a's net, and
///    b's old net is destroyed;
///  - only one port on a net: the other port joins that net;
///  - neither port on a net: a new net holding both ports is created.
///
/// Null ports and self-connections (a == b) are ignored. After a net is
/// removed, Net::index is renumbered so it again equals the net's position in
/// `nets`; that is the value a freshly created net receives.
void CircuitSimulation::connectPort(Port* a, Port* b) {
	if (a == nullptr || b == nullptr || a == b) return;

	if (a->net && b->net) {
		if (a->net == b->net) return; // same net, do nothing

		// Capture both nets up front. The loop below rewrites b->net to a's
		// net, so b->net must not be used afterwards to find the net being
		// retired. Doing so would erase the merged net instead.
		Net* target = a->net;
		Net* retired = b->net;

		for (Port* p : retired->ports) {
			if (p == nullptr) continue;
			p->net = target;
			target->ports.push_back(p);
		}
		retired->ports.clear();

		auto it = std::find_if(nets.begin(), nets.end(),
			[retired](const std::unique_ptr<Net>& n) { return n.get() == retired; });
		if (it != nets.end()) {
			nets.erase(it); // destroys `retired`
			for (size_t i = 0; i < nets.size(); i++) {
				nets[i]->index = i;
			}
		}
	} else if (a->net) {
		a->net->ports.push_back(b);
		b->net = a->net;
	} else if (b->net) {
		b->net->ports.push_back(a);
		a->net = b->net;
	} else {
		auto net = std::make_unique<Net>();
		net->ports = { a, b };
		net->index = nets.size();

		a->net = net.get();
		b->net = net.get();

		nets.push_back(std::move(net));
	}

	ElectricalComponent* aOwner = a->getOwner();
	ElectricalComponent* bOwner = b->getOwner();
	if (!aOwner || !bOwner) return;

	CustomElectricalComponent* aRef = aOwner->serializedRef;
	CustomElectricalComponent* bRef = bOwner->serializedRef;

	// Records use CustomElectricalComponent::id for consistency with
	// save/load, and are only written when both ids are valid (>= 0).
	if (!aRef || !bRef || aRef->id < 0 || bRef->id < 0) return;

	// Avoid duplicate records for the same wire.
	auto exists = [](const CustomElectricalComponent* ref, int portIdx, int otherId, int otherPort) {
		for (const auto& c : ref->connections) {
			if (c.portIndex == portIdx && c.otherComponentId == otherId && c.otherPortIndex == otherPort)
				return true;
		}
		return false;
	};

	const int aPort = static_cast<int>(a->index);
	const int bPort = static_cast<int>(b->index);
	const int aId = static_cast<int>(aRef->id);
	const int bId = static_cast<int>(bRef->id);

	if (!exists(aRef, aPort, bId, bPort))
		aRef->connections.push_back({aPort, bId, bPort});

	if (!exists(bRef, bPort, aId, aPort))
		bRef->connections.push_back({bPort, aId, aPort});
}

/// Removes the component identified by `n` from the simulation and destroys
/// it. Does nothing if `n` is not one of the owned components.
///
/// Before destruction, each of the component's ports is removed from its
/// net's port list and its net pointer is cleared. Without this, nets would
/// keep raw pointers to freed ports; validate() only detects null entries,
/// not dangling ones.
///
/// NOTE(review): nets left empty stay in `nets`, and serialized connection
/// records on other components that reference this one are not touched.
void CircuitSimulation::clearElement(Node* n){
	auto it = std::find_if(
		components.begin(),
		components.end(),
		[n](const std::unique_ptr<ElectricalComponent>& ptr) {
			return ptr.get() == n;
		}
	);
	if (it == components.end()) return;

	if (*it) {
		for (Port* port : (*it)->getPorts()) {
			if (port == nullptr || port->net == nullptr) continue;
			auto& peers = port->net->ports;
			peers.erase(std::remove(peers.begin(), peers.end(), port), peers.end());
			port->net = nullptr;
		}
	}

	components.erase(it);
}

/*void CircuitSimulation::connectPortToNet(Port* p, Net* n) {
	if (p->net ==nullptr) {
		n->ports.push_back(p);
		p->net =n;
	} else {
		//TODO error, already has net
	}
};*/
