"use client";

import { EvalDataPoint } from "@shared/dto/eval";

/**
 * Compute aggregated stats to display when no item is selected.
 */
function computeStats(evalDataset: EvalDataPoint[]) {
  /**
   * Compute aggregated pass/fail stats to display when no item is selected.
   *
   * Each data point is counted once under its `group`, once under its
   * `subGroup` within that group, and once per entry of its `tags` array (a
   * tag listed twice on one point is counted twice). A point counts as passed
   * when `scoreBoolean` is truthy.
   *
   * Tallies are accumulated in `Map`s and only converted to plain objects at
   * the end. Writing data-derived keys into `{}` literals breaks for names that
   * collide with `Object.prototype`: a group named "constructor" crashed, and a
   * tag named "__proto__" wrote NaN counters onto `Object.prototype` itself.
   *
   * @param evalDataset - Data points to aggregate.
   * @returns `groupStats` keyed by group name (each entry carrying its own
   *   `total`/`trueCount` plus nested `subGroups`), and `tagStats` keyed by tag.
   */
  type Tally = { total: number; trueCount: number };
  type GroupTally = Tally & { subGroups: Record<string, Tally> };

  const groups = new Map<string, { tally: Tally; subGroups: Map<string, Tally> }>();
  const tags = new Map<string, Tally>();

  // Fetch the tally for `key`, creating a zeroed one the first time it is seen.
  const tallyFor = (map: Map<string, Tally>, key: string): Tally => {
    let tally = map.get(key);
    if (tally === undefined) {
      tally = { total: 0, trueCount: 0 };
      map.set(key, tally);
    }
    return tally;
  };

  // Record one data point against a tally.
  const count = (tally: Tally, passed: boolean) => {
    tally.total += 1;
    if (passed) {
      tally.trueCount += 1;
    }
  };

  for (const data of evalDataset) {
    const { group, subGroup, tags: pointTags, scoreBoolean } = data;
    const passed = Boolean(scoreBoolean);
    // Coerce to string, exactly as using these values as object keys did.
    const groupKey = String(group);

    let groupEntry = groups.get(groupKey);
    if (groupEntry === undefined) {
      groupEntry = { tally: { total: 0, trueCount: 0 }, subGroups: new Map() };
      groups.set(groupKey, groupEntry);
    }
    count(groupEntry.tally, passed);
    count(tallyFor(groupEntry.subGroups, String(subGroup)), passed);

    for (const tag of pointTags) {
      count(tallyFor(tags, String(tag)), passed);
    }
  }

  // Object.fromEntries defines own data properties, so even "__proto__" is
  // stored as an ordinary key instead of touching the prototype chain.
  const groupStats: Record<string, GroupTally> = Object.fromEntries(
    Array.from(groups, ([name, { tally, subGroups }]): [string, GroupTally] => [
      name,
      { subGroups: Object.fromEntries(subGroups), ...tally },
    ])
  );
  const tagStats: Record<string, Tally> = Object.fromEntries(tags);

  return { groupStats, tagStats };
}

// Conditionally color the text based on whether the count is zero:
/**
 * Text-color class for a "passed" count: green when at least one item passed,
 * muted neutral when the count is exactly zero.
 */
function getPassedClass(count: number) {
  if (count !== 0) {
    return "text-green-600";
  }
  return "text-neutral-500";
}

/**
 * Text-color class for a "failed" count: red when at least one item failed,
 * muted neutral when the count is exactly zero.
 */
function getFailedClass(count: number) {
  const isZero = count === 0;
  if (isZero) return "text-neutral-500";
  return "text-red-600";
}

/** Pass/fail counts for one group, sub-group, or tag. */
type PassFailTally = { total: number; trueCount: number };

/** Number of failed data points in a tally. */
function failCount({ total, trueCount }: PassFailTally) {
  return total - trueCount;
}

/**
 * `Array.prototype.sort` comparator for `[name, tally]` entries that puts the
 * entry with the most failures first. Entries with equal failure counts keep
 * their existing relative order, since `Array.prototype.sort` is stable.
 */
function byMostFailsDesc(
  [, a]: [string, PassFailTally],
  [, b]: [string, PassFailTally]
) {
  return failCount(b) - failCount(a);
}

/**
 * Renders "<passed> passed, <failed> failed (<passed> / <total>)", coloring
 * each count via `getPassedClass` / `getFailedClass`. Shared by the group,
 * sub-group, and tag rows so all three stay formatted identically.
 */
function PassFailSummary({ passed, total }: { passed: number; total: number }) {
  const failed = total - passed;
  const ratio = `${passed} / ${total}`;
  return (
    <>
      <span className={getPassedClass(passed)}>{passed} passed</span>
      ,{" "}
      <span className={getFailedClass(failed)}>{failed} failed</span>{" "}
      ({ratio})
    </>
  );
}

/**
 * Aggregated pass/fail view of an eval dataset.
 *
 * Renders one card per group (with its sub-groups nested underneath),
 * followed by a per-tag list. At every level, the entries with the most
 * failures are listed first.
 *
 * @param props.evalDataset - Data points to summarize.
 */
export function DatasetStats({ evalDataset }: { evalDataset: EvalDataPoint[] }) {
  const { groupStats, tagStats } = computeStats(evalDataset);

  const sortedGroups = Object.entries(groupStats).sort(byMostFailsDesc);
  const sortedTags = Object.entries(tagStats).sort(byMostFailsDesc);

  return (
    <div className="flex flex-col gap-4">
      {/* Groups & Subgroups Stats */}
      <div>
        <h2 className="text-lg font-medium mb-2">Total Eval Scores</h2>
        <div className="space-y-3">
          {sortedGroups.map(([groupName, groupData]) => (
            <div key={groupName} className="border p-2 rounded">
              <div className="font-medium">
                {groupName}{" "}
                <PassFailSummary
                  passed={groupData.trueCount}
                  total={groupData.total}
                />
              </div>

              {/* Sub-groups, most fails first */}
              <div className="ml-4 mt-1 space-y-1">
                {Object.entries(groupData.subGroups)
                  .sort(byMostFailsDesc)
                  .map(([subGroupName, subGroupData]) => (
                    <div key={subGroupName}>
                      {subGroupName}{" "}
                      <PassFailSummary
                        passed={subGroupData.trueCount}
                        total={subGroupData.total}
                      />
                    </div>
                  ))}
              </div>
            </div>
          ))}
        </div>
      </div>

      {/* Tags Stats */}
      <div>
        <h2 className="text-lg font-medium mb-2">Tags Stats</h2>
        <div className="space-y-1">
          {sortedTags.map(([tagName, tagData]) => (
            <div key={tagName}>
              <span className="font-medium">{tagName}:</span>{" "}
              <PassFailSummary passed={tagData.trueCount} total={tagData.total} />
            </div>
          ))}
        </div>
      </div>
    </div>
  );
}
