import React, { FC, PropsWithChildren, useEffect, useRef } from 'react'

/** Props accepted by the `TabTrap` component (in addition to `children`). */
type Props = {
  /** Optional CSS class name forwarded to the wrapper `<div>` that holds the trapped children. */
  className?: string
}

/** Elements that may take part in sequential (Tab key) focus navigation inside the trap. */
const TABBABLE_SELECTOR = 'input, textarea, [href], select, button, [tabindex]'

/** Collapses a sentinel button to zero size and full transparency while keeping it focusable. */
const SENTINEL_STYLE: React.CSSProperties = {
  height: 0,
  width: 0,
  display: 'block',
  opacity: 0,
  padding: 0,
  margin: 0,
  border: 'none',
}

/**
 * Returns the tabbable descendants of `root` in sequential focus order.
 *
 * The order is: elements with a positive `tabIndex` first (ascending, ties in document
 * order), then `tabIndex === 0` elements in document order. Elements with a negative
 * `tabIndex` are skipped because Tab navigation skips them. Disabled controls are skipped
 * because `focus()` on them does not move focus.
 *
 * @param root - Container to search; `null` (e.g. after unmount) yields an empty list.
 * @returns Tabbable elements, first-to-last in Tab order.
 */
function getTabbable(root: HTMLElement | null): HTMLElement[] {
  if (!root) return []

  const candidates = Array.from(root.querySelectorAll<HTMLElement>(TABBABLE_SELECTOR)).filter(
    (el) => el.tabIndex >= 0 && !el.matches(':disabled'),
  )

  // querySelectorAll yields document order and Array.prototype.sort is stable (ES2019),
  // so equal positive tabIndex values keep their document order.
  const explicitlyOrdered = candidates.filter((el) => el.tabIndex > 0).sort((a, b) => a.tabIndex - b.tabIndex)
  const naturallyOrdered = candidates.filter((el) => el.tabIndex === 0)
  return [...explicitlyOrdered, ...naturallyOrdered]
}

/**
 * Traps keyboard focus inside `children`.
 *
 * Two invisible, focusable sentinel buttons bracket the wrapper `<div>`. When the leading
 * sentinel receives focus, focus jumps to the last tabbable child. When the trailing sentinel
 * receives focus, focus jumps to the first tabbable child. Tab / Shift+Tab therefore cycle
 * within the wrapper. After mounting, the first tabbable child is focused following a 10 ms delay.
 * If there are no tabbable children, focus is left wherever it is.
 */
const TabTrap: FC<PropsWithChildren<Props>> = ({ children, className }) => {
  const rootRef = useRef<HTMLDivElement | null>(null)
  const tabStartRef = useRef<HTMLButtonElement | null>(null)
  const tabStopRef = useRef<HTMLButtonElement | null>(null)

  useEffect(() => {
    const tabStart = tabStartRef.current
    const tabStop = tabStopRef.current
    if (!tabStart || !tabStop || !rootRef.current) return undefined

    /** Focuses the first or last tabbable element inside the wrapper, if any exists. */
    const focusEdge = (edge: 'first' | 'last') => {
      // Re-query on every call so children added or removed after mount are respected.
      const tabbable = getTabbable(rootRef.current)
      const target = edge === 'first' ? tabbable[0] : tabbable[tabbable.length - 1]
      target?.focus()
    }

    // Initial focus is deferred by 10 ms (presumably to let children settle — TODO confirm).
    // The timer is cleared on cleanup so it cannot fire after unmount.
    const initialFocusTimer = setTimeout(() => focusEdge('first'), 10)

    // Focus reaching the leading sentinel (e.g. Shift+Tab from the first child) wraps to the end.
    const onTabStartFocus = () => focusEdge('last')
    // Focus reaching the trailing sentinel (e.g. Tab from the last child) wraps to the start.
    const onTabStopFocus = () => focusEdge('first')

    tabStart.addEventListener('focus', onTabStartFocus)
    tabStop.addEventListener('focus', onTabStopFocus)

    return () => {
      clearTimeout(initialFocusTimer)
      tabStart.removeEventListener('focus', onTabStartFocus)
      tabStop.removeEventListener('focus', onTabStopFocus)
    }
    // Ref objects are stable for the component's lifetime, so this effect only runs on mount/unmount.
  }, [])

  return (
    <>
      {/* Leading sentinel: receiving focus redirects it to the last tabbable child. */}
      <button type="button" style={SENTINEL_STYLE} ref={tabStartRef} />
      <div ref={rootRef} className={className}>
        {children}
      </div>
      {/* Trailing sentinel: receiving focus redirects it to the first tabbable child. */}
      <button type="button" style={SENTINEL_STYLE} ref={tabStopRef} />
    </>
  )
}

export default TabTrap
