import { useCallback } from 'react';
import { useSelector } from 'react-redux';
import { getModels } from '../slices/models';
import { format, parseISO } from 'date-fns';
import groupBy from 'lodash/groupBy';
import sumBy from 'lodash/sumBy';
import { getUsage, ModelUsage } from '../slices/usage';

/**
 * Aggregates the usage records held in the store for one model type into
 * data for a per-model summary table and a per-day chart.
 *
 * @param type - Model type to include. `'llm'` also matches `'vlm'` models;
 *   any other value must equal the record's `model_type` exactly.
 * @param usageFn - Extracts the metric of interest from one usage record.
 *   Models whose summed metric is falsy (0 or NaN) are dropped from every
 *   output.
 * @param columnArr - Record keys to total per model for the table.
 *   NOTE(review): presumably numeric fields of `ModelUsage`; values that are
 *   not finite numbers after coercion (missing, null, non-numeric) count as 0.
 * @returns
 *   - `mapModelToName`: maps a model id to its display name from the models
 *     slice, falling back to the id itself.
 *   - `chartData`: one row per distinct record `date` (in first-seen order),
 *     labelled with `'MMM dd'`, carrying `type` and, for each model in
 *     `modelsWithUsage`, the summed `usageFn` value for that date.
 *   - `modelsWithUsage`: the model names that appear in `tableData`.
 *   - `tableData`: one row per model name with `id`/`type` taken from its
 *     first record, the raw `usage` records, and one summed field per key in
 *     `columnArr`.
 */
const useModelUsage = (
  type: string,
  usageFn: (item: ModelUsage) => number,
  columnArr: string[]
) => {
  const models = useSelector(getModels);
  const usageData = useSelector(getUsage);
  const mapModelToName = useCallback(
    (name: string) =>
      models.find((model) => model.model === name)?.name || name,
    [models]
  );

  // Reads a column by dynamic key without resorting to `any`. Anything that
  // is not a finite number (missing field, null, non-numeric string) counts
  // as 0, so table totals are always numbers rather than `undefined` or a
  // concatenated string.
  const readColumn = (item: ModelUsage, key: string): number => {
    const value = Number(Reflect.get(item, key));
    return Number.isFinite(value) ? value : 0;
  };

  const groupedUsage = Object.entries(
    groupBy(usageData, (x) => x.model_name)
  ).map(([name, usage]) => ({
    // groupBy never produces empty groups, so the first record always exists.
    id: usage[0].model_id,
    name,
    type: usage[0].model_type,
    usage,
  }));
  const tableData = groupedUsage
    .filter((usageDatum) =>
      type === 'llm'
        ? ['llm', 'vlm'].includes(usageDatum.type)
        : usageDatum.type === type
    )
    // Drop models that recorded no usage for this metric.
    .filter((usageDatum) => Boolean(sumBy(usageDatum.usage, usageFn)))
    .map((usageDatum) => ({
      ...usageDatum,
      // Built in one pass instead of re-spreading an accumulator per key.
      // The explicit tuple return type keeps fromEntries from widening to any.
      ...Object.fromEntries(
        columnArr.map((key): [string, number] => [
          key,
          sumBy(usageDatum.usage, (item) => readColumn(item, key)),
        ])
      ),
    }));

  const modelsWithUsage = tableData.map((usageDatum) => usageDatum.name);
  // Set lookup keeps the per-record membership test constant-time.
  const chartedModels = new Set(modelsWithUsage);
  const usageByDate = Object.entries(groupBy(usageData, (x) => x.date));
  const chartData = usageByDate.map(([name, usage]) => {
    // Sum rather than overwrite: if several records share a model name on the
    // same date, the chart must agree with the per-model totals in tableData
    // instead of keeping only the last record.
    const totals = new Map<string, number>();
    for (const u of usage) {
      if (!chartedModels.has(u.model_name)) continue;
      totals.set(u.model_name, (totals.get(u.model_name) ?? 0) + usageFn(u));
    }
    const date = parseISO(name);
    return {
      date: format(date, 'MMM dd'),
      type,
      ...Object.fromEntries(totals),
    };
  });

  return { mapModelToName, chartData, modelsWithUsage, tableData };
};

export default useModelUsage;
