import { tailwindMerge } from '@air/tailwind-variants';
import { observeElementOffset, type ScrollToOptions, useVirtualizer, type VirtualItem } from '@tanstack/react-virtual';
import { type ComponentPropsWithoutRef, memo, type ReactNode, useEffect, useRef, useState } from 'react';

/**
 * Props for the virtualized `TreeList`. Every standard `div` attribute not
 * listed here is forwarded to the outer list element.
 */
export type TreeListProps<Data> = ComponentPropsWithoutRef<'div'> & {
  /** Estimated size of the row at `index`; used by the virtualizer until the row has been measured. */
  getEstimateSize: (index: number) => number;
  /** Optional row index the list scrolls to (smoothly) once, after its first items render. */
  initialIndex?: number;
  /** Rows to virtualize; the virtualizer's item count is `items.length`. */
  items: Array<Data>;
  /**
   * Renders a single row. Receives the row's `data`, its `index`, and the
   * virtualizer's `scrollToIndex` so a row can scroll the list itself.
   */
  renderListItem: (
    item: Pick<VirtualItem, 'index'> & {
      data: Data;
      scrollToIndex: (index: number, options: ScrollToOptions) => void;
    },
  ) => ReactNode;
  /**
   * The element that actually scrolls; the virtualizer reads its scroll offset.
   * Store it with `setState` (not a plain ref) so the list re-renders when the
   * element changes under conditional rendering (e.g. Desktop -> Mobile).
   */
  scrollElement: HTMLDivElement;
};

/**
 * Virtualized list rendered inside an externally owned scroll container.
 *
 * Only the rows returned by `virtualizer.getVirtualItems()` (visible rows plus an
 * overscan of 5) are mounted. Each mounted row is measured via
 * `virtualizer.measureElement`, so `getEstimateSize` only has to be an initial guess.
 *
 * When `initialIndex` is a non-negative number (including 0), the list
 * smooth-scrolls to that index once, the first time virtual items are available.
 */
const _TreeList = <Data,>({
  className,
  getEstimateSize,
  initialIndex,
  items,
  renderListItem,
  scrollElement,
  style,
  ...restOfProps
}: TreeListProps<Data>) => {
  // `true` means "no pending initial scroll". Compare explicitly rather than using
  // `!initialIndex`, which would wrongly treat index 0 as "nothing to scroll to".
  const [hasScrolled, setHasScrolled] = useState(initialIndex === undefined || initialIndex < 0);
  const listRef = useRef<HTMLDivElement | null>(null);

  const virtualizer = useVirtualizer({
    count: items.length,
    estimateSize: getEstimateSize,
    observeElementOffset,
    overscan: 5,
    // The list may be preceded by other content inside `scrollElement`; its offsetTop is
    // passed as the scroll margin. The ref is not attached on the very first render, so 0 is used then.
    scrollMargin: listRef.current?.offsetTop ?? 0,
    getScrollElement: () => scrollElement,
  });

  const virtualItems = virtualizer.getVirtualItems();

  /**
   * Scroll to the initial index once. We wait until the virtualizer has produced
   * items, because scrolling before that lands at an incorrect position.
   */
  useEffect(() => {
    if (hasScrolled || initialIndex === undefined || initialIndex < 0) return;
    if (!virtualItems.length) return;

    virtualizer.scrollToIndex(initialIndex, {
      behavior: 'smooth',
    });
    setHasScrolled(true);
  }, [hasScrolled, initialIndex, virtualItems, virtualizer]);

  return (
    <div
      className={tailwindMerge('relative w-full', className)}
      data-testid="TREE_LIST"
      ref={listRef}
      style={{
        ...style,
        // Applied last so a consumer style cannot remove the total height that
        // the absolutely positioned rows below are laid out against.
        height: virtualizer.getTotalSize(),
      }}
      {...restOfProps}
    >
      <div
        className="absolute left-0 top-0 flex w-full flex-col"
        style={{
          // Row offsets include the scroll margin; subtract it to position rows relative to this list.
          transform: `translateY(${(virtualItems[0]?.start ?? 0) - virtualizer.options.scrollMargin}px)`,
        }}
      >
        {virtualItems.map((item) => (
          <div key={item.key} data-index={item.index} ref={virtualizer.measureElement}>
            {renderListItem({
              data: items[item.index],
              index: item.index,
              scrollToIndex: virtualizer.scrollToIndex,
            })}
          </div>
        ))}
      </div>
    </div>
  );
};

/**
 * Memoized `TreeList`. `memo` returns a non-generic component type, so the
 * assertion restores `_TreeList`'s generic `<Data>` signature for callers.
 */
export const TreeList = memo(_TreeList) as typeof _TreeList;
