import * as d3 from 'd3'
import { useCallback, useEffect } from 'react'
import { createRoot } from 'react-dom/client'
import { Chart, ChartProps, ID_CHART_DIV } from '../chart'

/**
 * Selection extension: re-appends every selected element to its parent so it
 * becomes the last child (and is therefore painted on top of its siblings).
 *
 * The per-element callback must be a regular `function`, not an arrow
 * function: d3 invokes it with `this` bound to the current DOM element,
 * whereas an arrow function would capture the outer `this` (the selection
 * itself), which has no `parentNode`.
 *
 * @returns the selection, so calls can be chained.
 */
d3.selection.prototype.bringElementAsTopLayer = function () {
  return this.each(function (this: Element) {
    // Detached elements have no parent; there is nothing to reorder for them.
    this.parentNode?.appendChild(this)
  })
}

// Layout constants (in pixels) for the chart scaffold.
// Within this file only `chart.padding`, `yAxis.width` and `xAxis.height`
// are read; the remaining entries are not referenced by the visible code.
const style = {
  chart: {
    // Padding applied to the root chart container.
    padding: 5,
  },
  controlPanel: {
    width: 250,
  },
  g: {
    margin: { right: 30, bottom: 60 },
  },
  yAxis: {
    // Width of the fixed column that holds the y-axis labels.
    width: 40,
    margin: { top: 0, left: 50, bottom: 60 },
  },
  xAxis: {
    // Vertical space subtracted from the plot height for the x axis.
    height: 30,
  },
  tickLabel: {
    width: 35,
    marginHorizontal: 3,
  },
}

// DOM id of the horizontally scrollable div that wraps the main chart SVG.
const CHART_MAIN = 'chartMain'
// Fraction of the visible chart width at which `baseXValue` is placed when
// the chart is scrolled to it after plotting.
const DEFAULT_BASE_POS_RATIO: number = 1 / 3
// Top y coordinate (px) used as the upper end of the selected-x highlight line.
export const CHART_MAIN_TOP: number = 12

/** Axis ranges and sizing used to lay out a {@link TwoDimChart}. */
export type TwoDimChartScale = {
  /** CSS width (px) applied to the chart container; `0px` is used when absent. */
  width?: number
  /** X values that get a tick; their min/max also define the x domain. */
  xValues: number[]
  /** Multiplied by `xValues.length` to obtain the minimum plot width (px). */
  xWidth: number
  /** Not read by the plotting code in this file (the y domain starts at 0). */
  yMin: number
  /** Upper y bound; the y domain is `[0, yMax * 1.2]`, or `[0, 1]` when it is 0. */
  yMax: number
}

/** Converts an axis value into its tick label. */
export type ValueFormatter = (value: number) => string

/** Props of {@link TwoDimChart}. */
type Props = ChartProps & {
  /** Plotters whose `plot` is invoked, in order, after the axes are drawn. */
  graph: TwoDimGraphPlotter[]
  /** Formatter for x tick labels. */
  xFormatter?: ValueFormatter
  /** Formatter for y tick labels. */
  yFormatter?: ValueFormatter
  scale: TwoDimChartScale
  /** X value the chart is scrolled to after plotting. */
  baseXValue?: number
}

/** Default tick formatter: the number rendered through template-string conversion. */
const defaultValueFormatter: ValueFormatter = (v: number) => `${v}`

/**
 * Two-dimensional chart scaffold.
 *
 * Draws a fixed y-axis column and a horizontally scrollable plot area (x axis
 * plus horizontal grid lines) into the element with id `ID_CHART_DIV`, then
 * lets every plotter in `graph` draw into the resulting context. After
 * plotting, the scroll position is moved so that `baseXValue` sits at
 * `DEFAULT_BASE_POS_RATIO` of the visible width.
 *
 * The chart is (re)drawn when `graph`, `scale.width` or `isLoading` changes
 * and `isLoading` is falsy.
 */
export const TwoDimChart = (props: Props) => {
  const {
    graph,
    scale: { width, xValues, xWidth, yMax },
    xFormatter = defaultValueFormatter,
    yFormatter = defaultValueFormatter,
    isLoading,
    baseXValue,
  } = props
  // Math.min()/Math.max() of an empty list are +/-Infinity, which would give
  // the x scale a non-finite domain; fall back to 0 for an empty axis.
  const hasXValues = xValues.length > 0
  const xMin = hasXValues ? Math.min(...xValues) : 0
  const xMax = hasXValues ? Math.max(...xValues) : 0
  // Spacing between the first two x values; pads the x domain on both sides.
  const xDelta = xValues.length > 1 ? xValues[1] - xValues[0] : 0

  /**
   * Clears the container and draws the axes and grid.
   *
   * @returns the drawing context handed to the plotters, or `undefined` when
   *   the chart container is not in the DOM.
   */
  const plotQuadrant = useCallback((): TwoDimChartContext | undefined => {
    const xNum = xValues.length
    d3.selectAll(`#${ID_CHART_DIV} > div`).remove()
    const root = d3.select<HTMLElement, any>(`#${ID_CHART_DIV}`)
    const rootNode = root.node()
    if (!rootNode) {
      // Nothing to draw into (container not rendered).
      return undefined
    }
    root
      .style('width', width ? `${width}px` : '0px')
      .style('height', `100%`)
      .style('padding', `${style.chart.padding}px`)
      .style('display', 'flex')
    // Sizes are measured after the styles above have been applied.
    const rootWidth = rootNode.offsetWidth
    // 30px is trimmed from the container height; the reason is not evident
    // from this file (presumably room for the horizontal scrollbar - confirm).
    const height = rootNode.offsetHeight - 30
    // The plot is at least `xWidth` px per x value, so it may be wider than
    // the visible area and scroll horizontally.
    const innerWidth =
      Math.max(rootWidth - style.yAxis.width, xWidth * xNum) - 10
    const innerHeight = height - style.xAxis.height

    // Fixed (non-scrolling) column holding the y-axis labels.
    const yAxisSvg = root
      .append('div')
      .style('width', `${style.yAxis.width}px`)
      .style('height', `100%`)
      .style('top', `${style.chart.padding}px`)
      .append('svg')
      .style('width', `${style.yAxis.width}px`)
      .style('height', `${height}px`)
      .append('g')

    // Scrollable plot area; also the parent of portals/tooltips.
    const chart = root
      .append('div')
      .attr('id', CHART_MAIN)
      .style('overflow-x', 'scroll')
      .style('overflow-y', 'hidden')
      .style('width', `100%`)
      .style('height', `100%`)
      .style('position', 'relative')
      .on('click', e => {
        // Dismiss the tooltip when the click lands neither on the tooltip
        // itself nor on the element that opened it.
        const hasTooltip = !d3
          .select(`#${CHART_MAIN} > #${ID_TOOLTIP}`)
          .empty()
        const tooltipClicked = e.target.closest(`#${ID_TOOLTIP}`)
        const tooltipTriggerClicked = e.target.classList.contains(
          CLASS_TOOLTIP_TRIGGER_ELEMENT
        )
        if (hasTooltip && !tooltipClicked && !tooltipTriggerClicked) {
          removeTooltip()
          removeSelectXLine()
        }
      })
      .append('svg')
      .style('width', `${innerWidth}px`)
      .style('height', `${height}px`)
      .append('g')

    // X axis: one tick per x value, with padding on both ends of the domain.
    const xScale = d3
      .scaleLinear()
      .domain([xMin - xDelta / 5, xMax + xDelta])
      .range([0, innerWidth])
    const xAxis = d3
      .axisBottom(xScale)
      .tickValues(xValues)
      .tickFormat(v => xFormatter(v.valueOf()))
    // Offset/rotation applied to the x tick labels.
    const translateX = 4
    const translateY = -8
    const rotateDeg = 0
    chart
      .append('g')
      .attr('class', 'x axis')
      .attr('transform', `translate(0,${innerHeight})`)
      .call(xAxis)
      .selectAll('text')
      .attr('y', 16)
      .attr('x', 16)
      .style(
        'transform',
        `translate(${translateX}px, ${translateY}px) rotate(${rotateDeg}deg)`
      )

    // Y scale starts at 0 and leaves 20% headroom above yMax; a yMax of 0
    // would collapse the domain, so [0, 1] is used in that case.
    const yScale =
      yMax === 0
        ? d3.scaleLinear().domain([0, 1]).range([innerHeight, 0])
        : d3
            .scaleLinear()
            .domain([0, yMax * 1.2])
            .range([innerHeight, 0])
    const tickCount = 10
    // Horizontal grid lines: a label-less left axis whose ticks span the plot.
    const yAxis = d3
      .axisLeft(yScale)
      .ticks(tickCount)
      .tickFormat(() => '')
      .tickSize(-innerWidth)
    chart
      .append('g')
      .attr('class', 'chart-grid')
      .attr('transform', `translate(0,0)`)
      .call(yAxis)
    // Labelled y axis, drawn in the fixed column.
    const yAxisTick = d3
      .axisLeft(yScale)
      .ticks(tickCount)
      .tickFormat(v => yFormatter(v.valueOf()))
    yAxisSvg
      .append('g')
      .attr('class', 'y axis')
      .attr('transform', `translate(${style.yAxis.width},0)`)
      .call(yAxisTick)
    return { chart, xScale, yScale, innerHeight }
  }, [width, xValues, xMin, xMax, xDelta, xWidth, yMax, xFormatter, yFormatter])

  /** Draws the scaffold, runs every plotter and scrolls to `baseXValue`. */
  const plot = useCallback(() => {
    const ctx = plotQuadrant()
    if (!ctx) {
      return
    }
    graph.forEach(g => g.plot(ctx))

    const chartNode = d3.select<Element, any>(`#${CHART_MAIN}`).node()
    // `!= null` rather than a truthiness test: 0 is a valid x value.
    if (chartNode && chartNode.clientWidth && baseXValue != null) {
      const orgX = ctx.xScale(baseXValue)
      const dstX = chartNode.clientWidth * DEFAULT_BASE_POS_RATIO
      chartNode.scrollLeft = orgX - dstX
    }
  }, [plotQuadrant, graph, baseXValue])

  useEffect(() => {
    if (!isLoading) {
      plot()
    }
    // Redraw is triggered by the data (`graph`), the requested width and the
    // loading flag. `plot` is deliberately not a trigger: it changes whenever
    // a parent passes a fresh `xValues` array or inline formatter, and
    // redrawing on every such render would reset the scroll position and
    // remove open tooltips. Because the callbacks above list all of their
    // inputs, the `plot` captured here is always the one of the current render.
  }, [graph, width, isLoading])

  return <Chart isLoading={isLoading} />
}

/** Drawing context produced by the chart scaffold and passed to each plotter. */
export interface TwoDimChartContext {
  /** Root `<g>` of the scrollable plot SVG; plotters append their marks here. */
  chart: d3.Selection<SVGGElement, unknown, HTMLElement, any>
  /** Maps x values to horizontal pixel positions inside `chart`. */
  xScale: d3.ScaleLinear<number, number>
  /** Maps y values to vertical pixel positions (range runs from `innerHeight` to 0). */
  yScale: d3.ScaleLinear<number, number>
  /** Height (px) of the plot area above the x axis. */
  innerHeight: number
}

/** Base class for anything that can draw itself into a two-dimensional chart. */
export abstract class TwoDimGraphPlotter {
  /** Draws this graph into the given chart context. */
  abstract plot(ctx: TwoDimChartContext): void
}

// Anchor of a portal relative to its target position.
// Horizontally: 'left' applies no shift, 'right' shifts the portal left by its
// full width, 'top'/'bottom' centre it. Vertically: 'bottom' shifts the portal
// up by 1.1x its height, every other value applies no shift.
export type PortalOrigin = 'top' | 'bottom' | 'left' | 'right'

/**
 * Appends a `div.portal` to the main chart container and renders `content`
 * into it through its own React root.
 *
 * The portal is positioned asynchronously: 100 ms after creation (so the
 * rendered content can be measured) its `top`/`left` are computed from the
 * anchor, the offsets and the container's scroll position captured at call
 * time.
 *
 * NOTE(review): the React root created here is not retained anywhere, so it
 * is never unmounted; `removePortal` only detaches the DOM node.
 *
 * @param id DOM id given to the portal element.
 * @param content React element rendered inside the portal.
 * @param originX horizontal anchor of the portal.
 * @param originY vertical anchor of the portal.
 * @param dx horizontal offset (px) relative to the visible chart area.
 * @param dy vertical offset (px) relative to the visible chart area.
 * @param notStickOutScreen when true, the position is clamped to the visible area.
 * @param onPortalRendered invoked just before positioning with the portal's
 *   parent element, the portal element and its first child.
 * @returns the d3 selection of the portal element; it is empty when the chart
 *   container does not exist, in which case nothing is rendered.
 */
export const addPortal = (
  id: string,
  content: JSX.Element,
  originX: PortalOrigin = 'left',
  originY: PortalOrigin = 'bottom',
  dx: number = 0,
  dy: number = 0,
  notStickOutScreen: boolean = true,
  onPortalRendered?: (
    root: HTMLElement | null,
    container: HTMLDivElement,
    content: Element | null
  ) => void
): d3.Selection<HTMLDivElement, any, Element, any> => {
  const root = d3.select<Element, any>(`#${CHART_MAIN}`)
  const rootNode = root.node()
  const scrollTop = rootNode?.scrollTop ?? 0
  const scrollLeft = rootNode?.scrollLeft ?? 0
  const clientWidth = rootNode?.clientWidth ?? 0
  const portal = root
    .append('div')
    .attr('id', `${id}`)
    .attr('class', 'portal')
    .style('z-index', 100000)
    .style('display', 'block')
  const node = portal.node()
  if (!node) {
    // The chart container is missing, so no element was appended. Creating a
    // React root on a null container would throw; hand back the empty
    // selection instead.
    return portal
  }
  createRoot(node).render(content)
  // Wait for the content to be rendered so the portal can be measured.
  setTimeout(() => {
    if (onPortalRendered) {
      const rootElement = node.parentElement
      const contentElement = node.firstElementChild
      onPortalRendered(rootElement, node, contentElement)
    }

    const xDelta = calculateXDelta(
      node,
      originX,
      dx,
      scrollLeft,
      clientWidth,
      notStickOutScreen
    )
    const yDelta = calculateYDelta(
      node,
      originY,
      dy,
      scrollTop,
      notStickOutScreen
    )
    portal.style('top', `${yDelta}px`)
    portal.style('left', `${xDelta}px`)
  }, 100)
  return portal
}

/**
 * Computes the `left` position (px) of a portal inside the scrollable chart.
 *
 * @param node portal element; its `offsetWidth` is used as the portal width.
 * @param origin anchor: 'left' applies no shift, 'right' shifts by the full
 *   width, 'top'/'bottom' shift by half the width (centred).
 * @param dx horizontal offset relative to the visible area.
 * @param scrollLeft horizontal scroll offset of the chart container.
 * @param rootClientWidth visible width of the chart container.
 * @param notStickOutScreen when true, the result is clamped so the portal
 *   starts no further left than `scrollLeft` and, if it would extend past the
 *   right edge of the visible area, is aligned to that edge.
 * @returns the horizontal position in container coordinates.
 */
const calculateXDelta = (
  node: HTMLDivElement,
  origin: PortalOrigin,
  dx: number,
  scrollLeft: number,
  rootClientWidth: number,
  notStickOutScreen: boolean = true
): number => {
  // Shift caused by the anchor alone.
  let anchorShift: number = 0
  if (origin === 'top' || origin === 'bottom') {
    anchorShift = (-1 * (node.offsetWidth || 0)) / 2
  } else if (origin === 'right') {
    anchorShift = -1 * node.offsetWidth || 0
  }

  const requested = dx + scrollLeft + anchorShift
  if (!notStickOutScreen) {
    return requested
  }

  // Clamp to the left edge first, then pull back from the right edge.
  const leftClamped = Math.max(scrollLeft, requested)
  const portalWidth = node.offsetWidth
  const visibleRight = scrollLeft + rootClientWidth
  return visibleRight < leftClamped + portalWidth
    ? visibleRight - portalWidth
    : leftClamped
}

/**
 * Computes the `top` position (px) of a portal inside the chart container.
 *
 * @param node portal element; its `offsetHeight` is used for the 'bottom' anchor.
 * @param origin anchor: 'bottom' lifts the portal by 1.1x its height, every
 *   other value applies no shift.
 * @param dy vertical offset relative to the visible area.
 * @param scrollTop vertical scroll offset of the chart container.
 * @param notStickOutScreen when true, the result is never above `scrollTop`.
 * @returns the vertical position in container coordinates.
 */
const calculateYDelta = (
  node: HTMLDivElement,
  origin: PortalOrigin,
  dy: number,
  scrollTop: number,
  notStickOutScreen: boolean = true
): number => {
  const anchorShift: number =
    origin === 'bottom' ? (node?.offsetHeight || 0) * -1.1 : 0
  const requested = dy + scrollTop + anchorShift
  if (notStickOutScreen) {
    return Math.max(scrollTop, requested)
  }
  return requested
}

/**
 * Detaches the portal with the given id from the main chart container.
 * Does nothing when no such portal exists.
 */
export const removePortal = (id: string) => {
  // Only a direct child of the chart container is matched.
  d3.select(`#${CHART_MAIN} > #${id}`).remove()
}

// DOM id of the tooltip portal (a direct child of the main chart container).
const ID_TOOLTIP = 'tooltip'
// CSS class marking the element that opened the current tooltip; clicks on
// such an element do not dismiss the tooltip.
const CLASS_TOOLTIP_TRIGGER_ELEMENT: string = 'tooltip_trigger'

/**
 * Reports whether a tooltip portal is currently attached to the chart.
 *
 * `d3.select` always returns a selection object (truthy even when nothing
 * matched), so presence has to be tested with `empty()`.
 */
const tooltipExists = (): boolean => {
  return !d3.select(`#${CHART_MAIN} > #${ID_TOOLTIP}`).empty()
}
/**
 * Shows a tooltip at the pointer position of `event`.
 *
 * The pointer position is taken relative to the main chart container, and
 * `event.target` is handed to `addTooltip` as the tooltip's trigger.
 *
 * @param content text or React element to display.
 * @param event pointer event that opened the tooltip.
 * @param originX horizontal anchor of the tooltip.
 * @param originY vertical anchor of the tooltip.
 * @param dx extra horizontal offset (px) added to the pointer position.
 * @param dy extra vertical offset (px) added to the pointer position.
 */
export const addTooltipByEvent = (
  content: string | JSX.Element,
  event: any,
  originX: PortalOrigin = 'left',
  originY: PortalOrigin = 'bottom',
  dx: number = 0,
  dy: number = 0
) => {
  const chartNode = d3.select<Element, any>(`#${CHART_MAIN}`).node()
  const [pointerX, pointerY] = d3.pointer(event, chartNode)
  addTooltip(
    content,
    event.target,
    originX,
    originY,
    pointerX + dx,
    pointerY + dy
  )
}

/**
 * Shows a tooltip, replacing any tooltip that is currently displayed.
 *
 * The trigger is tagged with `CLASS_TOOLTIP_TRIGGER_ELEMENT` and the content
 * is rendered into a portal with id `ID_TOOLTIP`.
 *
 * @param content text or React element to display; plain text is wrapped in a
 *   fragment.
 * @param rootNodeSelector the element that opened the tooltip, given either
 *   as a CSS selector (first match in the document) or as the element itself
 *   (`addTooltipByEvent` passes `event.target`).
 * @param originX horizontal anchor of the tooltip.
 * @param originY vertical anchor of the tooltip.
 * @param dx horizontal offset (px) relative to the visible chart area.
 * @param dy vertical offset (px) relative to the visible chart area.
 */
export const addTooltip = (
  content: string | JSX.Element,
  rootNodeSelector: string | Element,
  originX: PortalOrigin = 'left',
  originY: PortalOrigin = 'bottom',
  dx: number = 0,
  dy: number = 0
) => {
  if (tooltipExists()) {
    removeTooltip()
  }
  const tooltipContent = typeof content === 'string' ? <>{content}</> : content
  // Resolve the trigger to a single element whichever form it was given in.
  const triggerNode: Element | null =
    typeof rootNodeSelector === 'string'
      ? d3.select<Element, unknown>(rootNodeSelector).node()
      : rootNodeSelector
  if (triggerNode) {
    d3.select(triggerNode).classed(CLASS_TOOLTIP_TRIGGER_ELEMENT, true)
  }
  addPortal(ID_TOOLTIP, tooltipContent, originX, originY, dx, dy)
}

/**
 * Removes the tooltip portal and clears the trigger mark from the first
 * marked element inside the main chart container.
 */
export const removeTooltip = () => {
  removePortal(ID_TOOLTIP)
  // `classed` on an empty selection does nothing, so no presence check is needed.
  d3.select(`#${CHART_MAIN} .${CLASS_TOOLTIP_TRIGGER_ELEMENT}`).classed(
    CLASS_TOOLTIP_TRIGGER_ELEMENT,
    false
  )
}

// DOM id of the vertical highlight line that marks the selected x position.
const ID_SELECT_X_LINE: string = 'selectedX'

/**
 * Highlights the x position `x` with a wide, translucent vertical line,
 * replacing any previous highlight.
 *
 * @param x horizontal position (px) inside the plot group.
 * @param innerHeight y coordinate (px) at which the line ends.
 */
export const addSelectXLine = (x: number, innerHeight: number) => {
  removeSelectXLine()
  const plotGroup = d3.select<HTMLDivElement, any>(`#${CHART_MAIN} > svg > g`)
  // Inserted as the first child so every other mark is painted above it.
  const highlight = plotGroup.insert('line', ':first-child')
  highlight
    .attr('id', ID_SELECT_X_LINE)
    .attr('x1', x)
    .attr('x2', x)
    .attr('y1', CHART_MAIN_TOP)
    .attr('y2', innerHeight)
    .attr('stroke', '#FFFF00')
    .attr('stroke-width', '20px')
    .style('opacity', 0.25)
}

/** Removes the selected-x highlight line, if one is present. */
export const removeSelectXLine = () => {
  // Removing an empty selection is a no-op, so no presence check is required.
  d3.select<HTMLDivElement, any>(`#${ID_SELECT_X_LINE}`).remove()
}
