import { useEffect } from 'react';
import { Node, Edge, Position, ReactFlowState, useStore, useReactFlow } from 'reactflow';
import dagre from 'dagre';
import { DirectionalLayoutModel } from '../Models/DirectionalLayoutModel';

/**
 * Translates one letter of a two-letter layout direction (e.g. the `T` or `B`
 * in `'TB'`) into the React Flow handle side it names.
 */
const positionMap: Record<string, Position> = {
    B: Position.Bottom,
    L: Position.Left,
    R: Position.Right,
    T: Position.Top,
};

/** Store selector: the number of nodes currently held in React Flow's node internals. */
function nodeCountSelector({ nodeInternals }: ReactFlowState): number {
    const { size } = nodeInternals;
    return size;
}

/**
 * Store selector: `true` once every node has a truthy `width` and `height`.
 * A node whose width or height is missing, `null`, or `0` counts as not yet
 * measured. An empty node set yields `true`.
 */
function nodesInitializedSelector(state: ReactFlowState): boolean {
    const allNodes = Array.from(state.nodeInternals.values());
    const hasUnmeasured = allNodes.some((candidate) => !candidate.width || !candidate.height);
    return !hasUnmeasured;
}

/**
 * Lays out all React Flow nodes with dagre once every node has been measured.
 *
 * The layout effect runs when the node count, the "all nodes measured" flag, or the
 * `directionalLayoutModel` reference changes. Each run reads the current nodes and
 * edges, builds a dagre graph, runs `dagre.layout`, and writes the computed positions
 * back through `setNodes`.
 *
 * Handle sides come from `layoutDirection`. Its first letter selects the target handle
 * side and its second letter the source handle side (e.g. `'TB'` gives target = Top,
 * source = Bottom).
 *
 * @param directionalLayoutModel - dagre options: direction (`rankdir`), rank separation,
 *   node separation and ranker. NOTE(review): `layoutDirection` is assumed to be a
 *   two-letter dagre rankdir such as `'TB'` or `'LR'`; confirm against the model. The
 *   effect depends on this object's identity, so callers should memoize it.
 */
export function useAutoLayout(directionalLayoutModel: DirectionalLayoutModel) {
    const nodeCount = useStore(nodeCountSelector);
    const nodesInitialized = useStore(nodesInitializedSelector);
    const { getNodes, getEdges, setNodes, setEdges } = useReactFlow();

    useEffect(() => {
        // dagre needs real node sizes, so wait until every node has been measured.
        if (!nodeCount || !nodesInitialized) {
            return;
        }

        const { layoutDirection, rankSeparation, nodeSeparation, rankAssignmentAlgorithm } =
            directionalLayoutModel;

        const dagreGraph = new dagre.graphlib!.Graph();
        // Every edge needs a label object; we have no per-edge layout options.
        dagreGraph.setDefaultEdgeLabel(() => ({}));
        dagreGraph.setGraph({
            rankdir: layoutDirection,
            ranksep: rankSeparation,
            nodesep: nodeSeparation,
            ranker: rankAssignmentAlgorithm,
        });

        getNodes().forEach((node: Node) => {
            dagreGraph.setNode(node.id, {
                width: node.width,
                height: node.height,
            });
        });

        getEdges().forEach((edge: Edge) => {
            dagreGraph.setEdge(edge.source, edge.target);
        });

        dagre.layout(dagreGraph);

        // Identical for every node, so compute once outside the map.
        const sourcePosition = positionMap[layoutDirection[1]];
        const targetPosition = positionMap[layoutDirection[0]];

        setNodes((currentNodes) =>
            currentNodes.map((node) => {
                const laidOut = dagreGraph.node(node.id);
                // The updater sees the store's latest nodes, which may include nodes
                // missing from the snapshot given to dagre. Leave those untouched
                // instead of destructuring an undefined entry.
                if (!laidOut) {
                    return node;
                }

                return {
                    ...node,
                    sourcePosition,
                    targetPosition,
                    // dagre reports node centres; React Flow positions are top-left corners.
                    position: {
                        x: laidOut.x - 0.5 * (node.width ?? 0),
                        y: laidOut.y - 0.5 * (node.height ?? 0),
                    },
                    // Preserve existing per-node styling; only force the node visible.
                    style: { ...node.style, opacity: 1 },
                };
            }),
        );

        // Replaces every edge with a shallow copy. NOTE(review): presumably to make
        // React Flow re-render edges after the node move; confirm it is still needed.
        setEdges((currentEdges) => currentEdges.map((edge) => ({ ...edge })));
    }, [nodeCount, nodesInitialized, getNodes, getEdges, setNodes, setEdges, directionalLayoutModel]);
}
