import { hierarchy as d3Hierarchy, tree as d3Tree } from 'd3-hierarchy';
import { LayoutTreeNodeMap, LayoutTree, LinkMap, Point, Tree, TreeNode, Link } from '../tree.typing';
import { NODE_HEIGHT, SIMULATION_ROOT } from '../tree.const';

/**
 * Horizontal per-node step handed to `getNodeLayout` as its `nodeWidth`
 * argument (which feeds d3 `tree().nodeSize`).
 */
const nodeWidth = 240;

/**
 * Looks up the laid-out position of a single node.
 *
 * The whole tree layout is recomputed on every call, so `computeTreeLayout`
 * is the better choice when positions for many nodes are needed.
 *
 * NOTE(review): when `nodeId` is absent from the layout (unknown id, or a node
 * not reachable from `tree.rootId`) this returns `undefined`, even though the
 * declared return type is `Point`. Callers should guard accordingly.
 *
 * @param tree - Tree to lay out.
 * @param nodeId - Id of the node whose coordinates are wanted.
 * @returns The node's `{ x, y }` layout coordinates.
 */
export function getNodePoint(tree: Tree, nodeId: string): Point {
  const layoutNode = getNodeLayout(tree, nodeWidth, NODE_HEIGHT)[nodeId];
  return layoutNode ? { x: layoutNode.x, y: layoutNode.y } : layoutNode;
}

/**
 * Computes node positions and parent-to-children connector links for a tree.
 *
 * @param tree - Tree to lay out; a missing (falsy) tree yields an empty layout.
 * @returns Laid-out nodes keyed by id, plus one link entry per parent that has
 *   at least one connectable child.
 */
export function computeTreeLayout(tree: Tree | undefined): LayoutTree {
  // Nothing to lay out: hand back an empty layout instead of failing.
  if (!tree) return { nodes: {}, links: {} };

  const layoutNodes = getNodeLayout(tree, nodeWidth, NODE_HEIGHT);
  return { nodes: layoutNodes, links: getLinks(layoutNodes) };
}

/**
 * Runs d3's tidy-tree layout over `tree`, starting at `tree.rootId`, and
 * returns every node reachable from the root, keyed by its `id`, with the
 * computed `x`/`y` merged into a copy of the node data.
 *
 * Spacing rules (multiples of `nodeWidth`, as used by d3 `separation`):
 * - Non-siblings get `scale * subTeamGap`. `scale` is 0.85 for trees with 30 or
 *   more nodes and 1 otherwise.
 * - Siblings get 0.85 when their parent has 10 or more visible children, and
 *   `scale` otherwise.
 *
 * Child ids that have no entry in `tree.nodes` are skipped rather than crashing
 * the layout.
 *
 * @param tree - Tree to lay out.
 * @param nodeWidth - Horizontal step passed to d3 `nodeSize`.
 * @param nodeHeight - Vertical step passed to d3 `nodeSize`.
 * @returns Map of node id to node data plus coordinates. It is empty when the
 *   tree has no nodes or `tree.rootId` does not resolve to a node.
 */
function getNodeLayout(tree: Tree, nodeWidth: number, nodeHeight: number): LayoutTreeNodeMap {
  const outputMap: LayoutTreeNodeMap = {};

  const nodeCount = Object.keys(tree.nodes).length;
  const root = tree.nodes[tree.rootId];
  // Without a resolvable root, d3 would build a node with undefined data and
  // the `.each` below would throw on `data.id`.
  if (nodeCount === 0 || !root) {
    return outputMap;
  }

  const compactNodeThreshold = 30;
  const crowdedSiblingThreshold = 10;
  const defaultScale = 1;
  const compactScale = 0.85;
  const scale = nodeCount >= compactNodeThreshold ? compactScale : defaultScale;
  const subTeamGap = 1.4;

  // Resolve child ids to nodes, dropping dangling ids so d3 never sees undefined data.
  const resolveChildren = (node: TreeNode): TreeNode[] =>
    node.children
      .map((childId) => tree.nodes[childId])
      .filter((child): child is TreeNode => child !== undefined);

  const hierarchy = d3Hierarchy<TreeNode>(root, resolveChildren);

  const layout = d3Tree<TreeNode>()
    .nodeSize([nodeWidth, nodeHeight])
    .separation((a, b) => {
      // Non-siblings get the wider `subTeamGap` spacing.
      if (!a.parent || a.parent !== b.parent) {
        return scale * subTeamGap;
      }
      // Count the siblings actually present in the hierarchy (after dangling ids were dropped).
      const siblingCount = a.parent.children?.length ?? 0;
      return siblingCount >= crowdedSiblingThreshold ? compactScale : scale;
    });

  layout(hierarchy).each(({ x, y, data }) => {
    outputMap[data.id] = { ...data, x, y };
  });

  return outputMap;
}

/**
 * Builds one connector link per laid-out parent node.
 *
 * Only children whose `parent` is not `SIMULATION_ROOT.id` are connected, and
 * child ids missing from `layoutNodeMap` are skipped. A parent with no
 * connectable children gets no link.
 *
 * Each link spans horizontally from the leftmost to the rightmost connected
 * child (`width = maxX - minX`). It is centred on the parent's x and placed at
 * `parent.y + nodeHeight - 50`.
 *
 * @param layoutNodeMap - Output of the layout step, keyed by node id.
 * @param nodeHeight - Vertical offset used to place the connector. NOTE(review):
 *   the 320 default is independent of `NODE_HEIGHT`; confirm that is intended.
 * @returns Map of parent node id to its link.
 */
function getLinks(layoutNodeMap: LayoutTreeNodeMap, nodeHeight: number = 320): LinkMap {
  const linkMap: LinkMap = {};

  for (const [id, { x, y, children }] of Object.entries(layoutNodeMap)) {
    const connections: Point[] = [];
    for (const childId of children) {
      const child = layoutNodeMap[childId];
      // Skip ids with no layout entry (previously a TypeError) and children of the simulation root.
      if (!child || child.parent === SIMULATION_ROOT.id) {
        continue;
      }
      connections.push({ x: child.x, y: child.y });
    }

    if (connections.length === 0) {
      continue;
    }

    const { minX, maxX } = connections.reduce(
      (bounds, point) => ({
        minX: Math.min(bounds.minX, point.x),
        maxX: Math.max(bounds.maxX, point.x),
      }),
      { minX: Infinity, maxX: -Infinity }
    );

    // Non-empty `connections` guarantees maxX >= minX, so no abs() is needed.
    const width = maxX - minX;

    const link: Link = {
      parentX: x,
      // NOTE(review): the 50-unit offset is an undocumented visual tweak kept as-is.
      y: y + nodeHeight - 50,
      x: x - width / 2,
      width,
      connections,
    };

    linkMap[id] = link;
  }

  return linkMap;
}
