import * as d3 from 'd3';
import { sankey as d3Sankey, sankeyJustify } from 'd3-sankey';
import { linkHorizontal } from 'd3-shape';
import React, { useEffect } from 'react';

import { SankeyLink, SankeyNode } from './types';

/**
 * Inputs accepted by {@link useSankeyDiagram}.
 */
interface UseSankeyDiagramParams {
  /** SVG height in pixels; the hook falls back to 600 when omitted. */
  height?: number;
  /** Diagram nodes; the hook uses each node's `name` as its d3-sankey id. */
  sankeyNodes: SankeyNode[];
  /** Diagram links; their `source`/`target` are matched against node names. */
  sankeyLinks: SankeyLink[];
  /** Container element the hook renders its SVG into. */
  htmlDivElementRef: React.RefObject<HTMLDivElement>;
}

/**
 * Checks that every link's `source` and `target` matches the `name` of some node.
 *
 * d3-sankey resolves link endpoints through `nodeId` (here `d.name`) and throws
 * for any endpoint it cannot find, so a single dangling link is enough to make
 * the whole layout fail.
 */
const linksReferenceKnownNodes = (nodes: SankeyNode[], links: SankeyLink[]): boolean => {
  // Typed as `unknown` so membership tests accept whatever type SankeyLink
  // declares for its endpoints (equivalent to a strict `===` comparison).
  const nodeNames = new Set<unknown>(nodes.map((node) => node.name));
  return links.every((link) => nodeNames.has(link.source) && nodeNames.has(link.target));
};

/**
 * Renders a d3-sankey diagram as an SVG inside the referenced `<div>` and keeps
 * it in sync with the data, the requested height and window resizes.
 *
 * @param params.height - SVG height in pixels; defaults to 600.
 * @param params.sankeyNodes - Nodes to lay out; each node's `name` is its id.
 * @param params.sankeyLinks - Links whose `source`/`target` refer to node names.
 * @param params.htmlDivElementRef - Container the SVG is (re)rendered into.
 * @returns `drawSankey`, which clears the container and redraws the diagram.
 */
export const useSankeyDiagram = ({
  height = 600,
  sankeyNodes,
  sankeyLinks,
  htmlDivElementRef
}: UseSankeyDiagramParams) => {
  /**
   * Clears the container and, when the data is drawable, lays out and renders
   * the diagram. The width is taken from `document.body`.
   */
  function drawSankey() {
    const container = htmlDivElementRef.current;
    if (!container) { return; }

    // Clear before validating so a diagram drawn for earlier data never
    // lingers once the data becomes empty or inconsistent.
    d3.select(container).html(null);

    if (sankeyNodes.length === 0 || sankeyLinks.length === 0) { return; }
    // d3-sankey throws on links whose endpoints are not known nodes; skip drawing instead.
    if (!linksReferenceKnownNodes(sankeyNodes, sankeyLinks)) { return; }

    const { width } = document.body.getBoundingClientRect();
    const format = d3.format(',.0f');

    const sankey = d3Sankey()
      .nodeAlign(sankeyJustify)
      .nodeId((d) => d.name)
      .extent([[1, 5], [width - 1, height - 5]])
      .nodePadding(10)
      .nodeWidth(15);

    // Copy inputs: d3-sankey mutates the objects it lays out.
    const { nodes, links } = sankey({
      nodes: sankeyNodes.map((node) => ({ ...node })),
      links: sankeyLinks.map((link) => ({ ...link }))
    });

    const path = linkHorizontal()
      .source((d) => [d.source.x1, d.y0])
      .target((d) => [d.target.x0, d.y1]);

    const svg = d3.select(container)
      .append('svg')
      .attr('viewBox', [0, 0, width, height])
      .attr('width', width)
      .attr('height', height)
      .attr('style', 'max-width: 100%; height: auto; font: 12px sans-serif;');

    const color = d3.scaleOrdinal(d3.schemeCategory10);
    const rect = svg.append('g')
      .attr('stroke', '#000')
      .selectAll()
      .data(nodes)
      .join('rect')
      .attr('x', (d) => d.x0)
      .attr('y', (d) => d.y0)
      .attr('height', (d) => d.y1 - d.y0)
      .attr('width', (d) => d.x1 - d.x0)
      .attr('fill', (d) => color(d.category));

    rect.append('title')
      .text((d) => `${d.name}\n${format(d.value)} USD`);

    const link = svg.append('g')
      .attr('fill', 'none')
      .attr('stroke-opacity', 0.5)
      .selectAll()
      .data(links)
      .join('g')
      .style('mix-blend-mode', 'multiply');

    link.append('path')
      .attr('d', (d) => path(d))
      .attr('stroke', (d) => color(d.target.category))
      .attr('stroke-width', (d) => Math.max(1, d.width));

    link.append('title')
      .text((d) => `${d.source.name} → ${d.target.name}\n${format(d.value)} USD`);

    // Labels sit to the right of nodes in the left half and to the left of
    // nodes in the right half, so they stay inside the SVG.
    svg.append('g')
      .selectAll()
      .data(nodes)
      .join('text')
      .attr('x', (d) => (d.x0 < width / 2 ? d.x1 + 6 : d.x0 - 6))
      .attr('y', (d) => (d.y1 + d.y0) / 2)
      .attr('dy', '0.35em')
      .attr('text-anchor', (d) => (d.x0 < width / 2 ? 'start' : 'end'))
      .text((d) => d.name);
  }

  useEffect(() => {
    // drawSankey clears the container itself, so redrawing on every run is
    // safe (including React StrictMode's double-invoked effects) and lets data
    // or height changes actually reach the screen.
    drawSankey();
    window.addEventListener('resize', drawSankey);

    return () => {
      window.removeEventListener('resize', drawSankey);
    };
  }, [htmlDivElementRef.current, height, sankeyNodes, sankeyLinks]);

  return { drawSankey };
};
