import { AccountPlan } from '@prisma/client';
import { ANALYSIS_TYPES, MODELS } from '@/constants';
import { ModerationHandler, ModerationResult } from '../../../../types';
import z from 'zod';
import { shouldFlagResult } from '../../flagging';

import { Comprehend } from '@aws-sdk/client-comprehend';

/**
 * Returns `{ credentials }` built from `AWS_ACCESS_KEY_ID` /
 * `AWS_SECRET_ACCESS_KEY` when both are set and non-empty, otherwise `{}`.
 *
 * Returning `{}` leaves the `credentials` option off the client config. The
 * SDK then resolves credentials through its default provider chain instead
 * of receiving `undefined` keys. The previous `as string` casts hid that
 * case from the type checker.
 */
function awsStaticCredentials() {
  const accessKeyId = process.env.AWS_ACCESS_KEY_ID;
  const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY;
  return accessKeyId && secretAccessKey
    ? { credentials: { accessKeyId, secretAccessKey } }
    : {};
}

/**
 * Module-wide AWS Comprehend client used for PII detection.
 * The region comes from `AWS_REGION` and falls back to `eu-central-1` when
 * it is unset or empty (hence `||` rather than `??`).
 */
const comprehendClient = new Comprehend({
  ...awsStaticCredentials(),
  region: process.env.AWS_REGION || 'eu-central-1',
});

/**
 * PII categories this analyzer reports.
 *
 * Member names are the label keys written into the analyzer's `label_scores`
 * and matched against the PII filter's `components`. Member values are the
 * entity type names compared against AWS Comprehend's returned labels. They
 * differ only for CREDIT_CARD_NUMBER, which Comprehend reports as
 * CREDIT_DEBIT_NUMBER.
 *
 * Member order matters: it is the iteration order of `Object.keys` /
 * `Object.entries`.
 */
export enum PII_LABEL {
  EMAIL = 'EMAIL',
  ADDRESS = 'ADDRESS',
  NAME = 'NAME',
  PHONE = 'PHONE',
  SSN = 'SSN',
  URL = 'URL',
  PASSPORT_NUMBER = 'PASSPORT_NUMBER',
  AGE = 'AGE',
  PASSWORD = 'PASSWORD',
  DRIVER_ID = 'DRIVER_ID',
  IP_ADDRESS = 'IP_ADDRESS',
  CREDIT_CARD_NUMBER = 'CREDIT_DEBIT_NUMBER',
}

/**
 * Human-readable description for each `PII_LABEL` key, plus `NEUTRAL`
 * (the synthetic "nothing found" label).
 */
export const PII_LABEL_DESCRIPTIONS = {
  EMAIL: 'Email address',
  ADDRESS: 'Street address etc.',
  NAME: 'Name of person',
  PHONE: 'Phone number',
  SSN: 'Social security number',
  URL: 'Any type of URL',
  PASSPORT_NUMBER: 'Passport number',
  AGE: 'Age',
  PASSWORD: 'Password',
  DRIVER_ID: 'Driver ID',
  IP_ADDRESS: 'IP address',
  CREDIT_CARD_NUMBER: 'Credit card number',
  NEUTRAL: 'Nothing found',
};

/**
 * @deprecated Misspelled name kept so existing importers keep working;
 * use `PII_LABEL_DESCRIPTIONS`. Both names refer to the same object.
 */
export const PII_LABEL_DECSCRIPTIONS = PII_LABEL_DESCRIPTIONS;

/**
 * Builds the PII moderation step.
 *
 * The outer call binds the shared `result` accumulator, the filter
 * configuration and the account `plan`. The returned async function analyses
 * one text `value` with AWS Comprehend `containsPiiEntities` and writes into
 * `result`:
 *
 * - `result[ANALYSIS_TYPES.PII] = { label, score, label_scores }`.
 *   `label_scores` maps each selected `PII_LABEL` key to Comprehend's score
 *   for the matching entity type (0 when not reported). It also carries a
 *   synthetic `NEUTRAL` score of `1 - max(selected scores)`. `label` / `score`
 *   are the highest entry; on ties the earliest key wins, and `NEUTRAL` comes
 *   last.
 * - `result.flagged = true` when `shouldFlagResult` says so.
 * - Quota usage is raised to `ceil(value.length / maxChars)` and never
 *   lowered. It goes to `result.request.base_quota_usage` when
 *   `plan.sumModelUsageEnabled`, otherwise to `result.request.quota_usage`.
 *
 * Nothing is done unless the PII filter is enabled. Any error thrown while
 * analysing is stored as `result[ANALYSIS_TYPES.PII] = { error }` instead of
 * propagating.
 *
 * `cache` is part of the `ModerationHandler` signature and is not used here.
 */
const moderationHandler: ModerationHandler =
  (result: ModerationResult, filter, plan: AccountPlan, cache) =>
  async (value) => {
    try {
      const piiFilter = filter[ANALYSIS_TYPES.PII];

      if (!piiFilter?.enabled) {
        return;
      }

      // Missing or empty `components` means "score every supported label".
      const attributes = piiFilter.components;

      const awsResult = await comprehendClient.containsPiiEntities({
        LanguageCode: 'en', // only english supported
        Text: value,
      });

      const labelScores: { [key: string]: number } = {};
      for (const [key, awsName] of Object.entries(PII_LABEL)) {
        if (attributes?.length && !attributes.includes(key)) {
          continue;
        }
        // Keys are our label names; values are Comprehend's entity type names.
        labelScores[key] =
          awsResult?.Labels?.find((ent) => ent.Name === awsName)?.Score ?? 0;
      }

      // Math.max() with no arguments is -Infinity. Seeding with 0 keeps
      // NEUTRAL at 1 (not Infinity) when no known label was selected.
      const maxScore = Math.max(0, ...Object.values(labelScores));
      labelScores['NEUTRAL'] = 1 - maxScore;

      // Pick the highest score. Strict `>` keeps the earliest label on ties,
      // the same result as taking the head of a stable descending sort.
      let label = 'NEUTRAL';
      let score = -Infinity;
      for (const [candidate, candidateScore] of Object.entries(labelScores)) {
        if (candidateScore > score) {
          label = candidate;
          score = candidateScore;
        }
      }

      result[ANALYSIS_TYPES.PII] = {
        label,
        score,
        label_scores: labelScores,
      };

      if (shouldFlagResult(piiFilter, labelScores)) {
        result.flagged = true;
      }

      // Usage = number of maxChars-sized chunks in the input, rounded up.
      const quotaUsage = Math.ceil(
        value.length / (MODELS[ANALYSIS_TYPES.PII]?.maxChars || 1000),
      );

      // Math.max: never lower a usage value already recorded on the request.
      if (plan.sumModelUsageEnabled) {
        result.request.base_quota_usage = Math.max(
          quotaUsage,
          result.request.base_quota_usage,
        );
      } else {
        result.request.quota_usage = Math.max(
          quotaUsage,
          result.request.quota_usage,
        );
      }
    } catch (error) {
      // Record the failure on the result rather than rejecting. Narrow first:
      // the thrown value is not guaranteed to be an Error (or even non-null).
      result[ANALYSIS_TYPES.PII] = {
        error:
          error instanceof Error && error.message ? error.message : 'error',
      };
    }
  };

/** Module facade: the PII analyzer exposes only its moderation handler factory. */
const piiAnalyzer = { moderationHandler: moderationHandler };

export default piiAnalyzer;
