import { GridOptions, RowDragEvent, RowNode } from 'ag-grid-community'
import { RefObject, useCallback, useState } from 'react'
import { TreeRow } from '../../model'
import { AUTO_COLUMN_ID } from '../../../BulkSheet/const'
import { getParentUuid } from '../../lib/tree'

/**
 * Publishes the current row-drag drop target into `gridOptions.context` so that
 * column definitions can restyle rows through `cellClassRules` while a drag is
 * in progress.
 *
 * Two drop modes are distinguished and written to the context as
 * `{ draggableNodeId, onTree }`:
 * - same level (`onTree: false`): the pointer is outside the auto group column
 *   and every dragged root node is a sibling of (but not) the hovered node.
 * - reparent (`onTree: true`): the pointer is over the auto group column and the
 *   hovered node may adopt every dragged root node.
 * When neither applies, `draggableNodeId` is cleared.
 *
 * @param ref Grid container element; its `offsetLeft` is added to the auto group
 *   column's left/right positions when hit-testing the pointer.
 * @param acceptChild Returns whether `parent` may receive `child` as a child row.
 * @param gridOptions Grid options whose `context` is replaced (merged) and whose
 *   cells are refreshed on every drag-state update.
 * @returns `onRowDragEnter` / `onRowDragMove` handlers for the grid's row-drag
 *   events, and `refreshDragStyle` to reset the drag state in the context.
 */
export const useDragTreeStyle = <R extends TreeRow>(
  ref: RefObject<HTMLDivElement>,
  acceptChild: (parent: R, child: R) => boolean,
  gridOptions: GridOptions
) => {
  // Dragged nodes whose parent is not itself part of the drag, i.e. the roots of
  // the dragged selection; descendants of another dragged node are left out.
  const [topLevelNodes, setTopLevelNodes] = useState<RowNode[]>([])

  const onRowDragEnter = useCallback((event: RowDragEvent) => {
    // Set lookup keeps this linear in the number of dragged nodes.
    const draggedIds = new Set(event.nodes.map(v => v.id))
    setTopLevelNodes(event.nodes.filter(n => !draggedIds.has(n.parent?.id)))
  }, [])

  // Whether the pointer is horizontally inside the auto group (tree) column.
  const onAutoColumn = useCallback(
    (event: RowDragEvent): boolean => {
      const autoGroupColumn = event.columnApi.getColumn(AUTO_COLUMN_ID)
      if (!autoGroupColumn) return false
      // NOTE(review): event.event.x is viewport-relative (clientX), while
      // offsetLeft is relative to the offsetParent and the column positions may
      // not account for horizontal scroll — confirm this matches the page layout.
      const marginLeft = ref.current?.offsetLeft ?? 0
      const pointerX = event.event.x
      return (
        (autoGroupColumn.getLeft() ?? 0) + marginLeft < pointerX &&
        (autoGroupColumn.getRight() ?? 0) + marginLeft > pointerX
      )
    },
    [ref]
  )

  // Merges the drag state into the grid context and refreshes cells so that
  // cellClassRules are re-evaluated against it.
  const refreshRows = useCallback(
    (params: { draggableNodeId?: string; onTree?: boolean }) => {
      gridOptions.context = { ...gridOptions.context, ...params }
      gridOptions.api?.refreshCells()
    },
    [gridOptions]
  )

  const onRowDragMove = useCallback(
    (event: RowDragEvent) => {
      const overNode = event.overNode
      if (!overNode || !overNode.id) return
      const overId = overNode.id

      const onTree = onAutoColumn(event)

      // Same level: every dragged root shares the hovered node's level and
      // parent, and the hovered node is not one of the dragged roots.
      if (
        !onTree &&
        topLevelNodes.every(
          v =>
            v.level === overNode.level &&
            v.id !== overId &&
            getParentUuid(v.data) === getParentUuid(overNode.data)
        )
      ) {
        refreshRows({ draggableNodeId: overId, onTree: false })
        return
      }

      // Reparent: the hovered node accepts every dragged root, is not one of
      // them, is not already their parent, and is not one of their descendants.
      // The descendant scan is last so it only runs when everything else passes.
      if (
        onTree &&
        topLevelNodes.every(v => acceptChild(overNode.data, v.data)) &&
        !topLevelNodes.some(v => v.id === overId) &&
        // NOTE(review): compares a parent uuid with a row node id; this assumes
        // row ids are the rows' uuids — confirm against the grid's row id setup.
        !topLevelNodes.some(v => getParentUuid(v.data) === overId) &&
        // Presumably treeValue is the hovered row's uuid path; a dragged uuid in
        // it means the target sits under a dragged row.
        !topLevelNodes.some(v => overNode.data.treeValue.includes(v.data.uuid))
      ) {
        refreshRows({ draggableNodeId: overId, onTree: true })
        return
      }

      // No valid drop target: clear the highlight.
      refreshRows({ draggableNodeId: undefined, onTree: false })
    },
    [topLevelNodes, onAutoColumn, refreshRows, acceptChild]
  )

  // Resets the drag state in the grid context to "no drop target".
  const refreshDragStyle = useCallback(() => {
    refreshRows({ draggableNodeId: undefined, onTree: false })
  }, [refreshRows])

  return { onRowDragEnter, onRowDragMove, refreshDragStyle }
}
