import React, { ReactNode, useCallback, useState } from 'react';
import {
  Box,
  SortDirection,
  SxProps,
  TableCell,
  TableSortLabel,
  Theme,
} from '@mui/material';

/**
 * A comparator in the `Array.prototype.sort` convention: a negative result
 * orders `left` first, a positive result orders `right` first, and 0 treats
 * the two as equivalent.
 */
export type Comparator<T> = (left: T, right: T) => number;
/**
 * Describes one column of a sortable table header, consumed by both
 * `useSorter` (for ordering rows) and `SortableColumns` (for rendering).
 */
export type SortableTableColumn<T> = {
  /**
   * Identifies the column: used as the React key, stored as
   * `TableSortState.colName` when sorted, and displayed when `heading` is absent.
   */
  name: string;
  /** Header content; falls back to `name` when omitted. */
  heading?: ReactNode;
  /** When true, the header renders a clickable `TableSortLabel`. */
  sortable?: boolean;
  /** Row ordering for this column; without it `useSorter` leaves rows as-is. */
  comparator?: Comparator<T>;
  /** Style overrides passed to the header `TableCell`. */
  sx?: SxProps<Theme>;
  /** Extra content rendered right-aligned inside the header cell. */
  rightIcons?: ReactNode;
};

/** Takes table rows and returns them in display order. */
export type TableSorter<T> = (input: T[]) => T[];

/**
 * Which column a table is sorted by, and in which direction.
 * `{ colName: null, direction: false }` is the unsorted state.
 */
export type TableSortState = {
  /** `name` of the sorted column, or `null` when no column is sorted. */
  colName: string | null;
  /** Sort direction; this file uses 'asc', 'desc', and `false` (unsorted). */
  direction: SortDirection;
};

/**
 * The unsorted state: no active column and `direction: false`. `useSorter`
 * starts from this value, and `SortableColumns` returns this same object when
 * a column that was sorted descending is clicked again.
 */
export const initialSortState: TableSortState = {
  colName: null,
  direction: false,
};

/**
 * Holds a table's sort state and derives a memoized row sorter from it.
 *
 * @param columns - column definitions; the active column's `comparator`
 *   decides the row order.
 * @returns `sort` (current state), `setSort` (React state setter, suitable for
 *   `SortableColumns`), and `sorter`, which returns a sorted copy of its input,
 *   or the input array itself when there is nothing to sort by (no active
 *   column, `direction: false`, or a column without a comparator).
 */
export function useSorter<T>({
  columns,
}: {
  columns: SortableTableColumn<T>[];
}) {
  const [sort, setSort] = useState<TableSortState>(initialSortState);
  const { colName, direction } = sort;
  const sorter = useCallback<TableSorter<T>>(
    (refs) => {
      // `direction === false` means "unsorted" even if a column name is set
      // (reachable when callers drive `setSort` directly).
      if (colName === null || direction === false) {
        return refs;
      }
      const comparator = columns.find((col) => col.name === colName)
        ?.comparator;
      if (!comparator) {
        return refs;
      }

      // Descending negates the whole comparator, so wrappers such as
      // `nullsLast` put nullish values first in that direction.
      const sign = direction === 'asc' ? 1 : -1;
      // Copy first: Array.prototype.sort mutates in place.
      return [...refs].sort((a, b) => comparator(a, b) * sign);
    },
    [columns, colName, direction]
  );

  return { sort, setSort, sorter };
}

/**
 * Renders one header `TableCell` per column.
 *
 * Sortable columns get a `TableSortLabel`. Clicking it moves the sort state
 * through ascending, then descending, then back to `initialSortState`.
 * Switching to a different column always starts ascending. Non-sortable
 * columns render their heading as plain content. Each column's `rightIcons`
 * sit in a right-aligned box.
 *
 * @param columns - column definitions, in display order
 * @param sort - current sort state
 * @param setSort - setter for the sort state
 */
export function SortableColumns<T>({
  columns,
  sort,
  setSort,
}: {
  columns: SortableTableColumn<T>[];
  sort: TableSortState;
  setSort: React.Dispatch<React.SetStateAction<TableSortState>>;
}) {
  // Builds the state updater applied when the label of column `name` is clicked.
  const advance =
    (name: string) =>
    (prev: TableSortState): TableSortState => {
      if (prev.colName === name) {
        // same column: asc -> desc, anything else -> unsorted
        return prev.direction === 'asc'
          ? { colName: name, direction: 'desc' }
          : initialSortState;
      }
      // a different column always starts ascending
      return { colName: name, direction: 'asc' };
    };

  const renderCell = (col: SortableTableColumn<T>) => {
    const label = col.heading ?? col.name;
    const isActive = sort.colName === col.name;
    return (
      <TableCell key={col.name} sx={col.sx}>
        <Box sx={{ display: 'flex' }}>
          {col.sortable ? (
            <TableSortLabel
              direction={isActive && sort.direction ? sort.direction : undefined}
              active={isActive}
              onClick={() => setSort(advance(col.name))}
              // Blur when the pointer leaves. Otherwise a clicked label keeps
              // focus and stays in the greyed focus/hover styling.
              onMouseLeave={(e) => e.currentTarget.blur()}
            >
              {label}
            </TableSortLabel>
          ) : (
            label
          )}
          <Box sx={{ marginLeft: 'auto' }}>{col.rightIcons}</Box>
        </Box>
      </TableCell>
    );
  };

  return <>{columns.map(renderCell)}</>;
}

/**
 * Turns a comparator over a field into a comparator over whole objects.
 *
 * @param extractor - reads the field to compare from an object
 * @param comparator - orders the extracted field values
 * @returns a comparator that extracts the field from both sides (left first)
 *   and delegates to `comparator`
 */
export function fieldComparator<ObjectT, FieldT>(
  extractor: (any: ObjectT) => FieldT,
  comparator: Comparator<FieldT>
): Comparator<ObjectT> {
  const compareObjects = (left: ObjectT, right: ObjectT): number => {
    const leftField = extractor(left);
    const rightField = extractor(right);
    return comparator(leftField, rightField);
  };
  return compareObjects;
}

/** Ascending numeric order; returns the raw difference `a - b`. */
export const numberComparator = (a: number, b: number): number => {
  const difference = a - b;
  return difference;
};
/** Locale-aware ascending string order via `String.prototype.localeCompare`. */
export const stringComparator = (a: string, b: string): number => {
  return a.localeCompare(b);
};
/**
 * Wraps a comparator so that `null` and `undefined` sort after every
 * non-nullish value in ascending order.
 *
 * Two nullish values compare as equal (0). The result is therefore a
 * consistent comparator (antisymmetric), and a stable sort keeps nullish
 * values in their original relative order. Non-nullish pairs are delegated
 * to `comparator`.
 *
 * Note: a caller that negates the result to sort descending will move
 * nullish values first.
 *
 * @param comparator - ordering for non-nullish values
 * @returns a comparator accepting `T | null | undefined`
 */
export function nullsLast<T>(
  comparator: Comparator<T>
): Comparator<T | null | undefined> {
  return (a: T | null | undefined, b: T | null | undefined) => {
    if (a === null || a === undefined) {
      // both missing -> equal; otherwise the missing `a` goes after `b`
      return b === null || b === undefined ? 0 : 1;
    }
    if (b === null || b === undefined) {
      return -1;
    }
    return comparator(a, b);
  };
}
